diff --git a/RELEASING/README.md b/RELEASING/README.md index af001ad6af94..629c8d820ca9 100644 --- a/RELEASING/README.md +++ b/RELEASING/README.md @@ -434,7 +434,7 @@ Create a virtual environment and install the dependencies ```bash cd ${SUPERSET_RELEASE_RC} -python3 -m venv venv +uv venv venv source venv/bin/activate pip install -r requirements/base.txt pip install build twine diff --git a/docker/docker-pytest-entrypoint.sh b/docker/docker-pytest-entrypoint.sh index f155ee4c698b..6f7c4ded723e 100755 --- a/docker/docker-pytest-entrypoint.sh +++ b/docker/docker-pytest-entrypoint.sh @@ -21,7 +21,7 @@ set -e # Wait for PostgreSQL to be ready echo "Waiting for database to be ready..." for i in {1..30}; do - if python3 -c " + if uv run python -c " import psycopg2 try: conn = psycopg2.connect(host='db-light', user='superset', password='superset', database='superset_light') @@ -45,7 +45,7 @@ done if [ "${FORCE_RELOAD}" = "true" ]; then echo "Force reload requested - resetting test database" # Drop and recreate the test database using Python - python3 -c " + uv run python -c " import psycopg2 from psycopg2.extensions import ISOLATION_LEVEL_AUTOCOMMIT @@ -82,7 +82,7 @@ else FLAGS="--no-reset-db" # Ensure test database exists using Python - python3 -c " + uv run python -c " import psycopg2 from psycopg2.extensions import ISOLATION_LEVEL_AUTOCOMMIT diff --git a/docs/dataframe_subsystem_maturity_report.md b/docs/dataframe_subsystem_maturity_report.md index 121897b0be2e..cdf8483fb12e 100644 --- a/docs/dataframe_subsystem_maturity_report.md +++ b/docs/dataframe_subsystem_maturity_report.md @@ -34,617 +34,3 @@ This stub is kept only to preserve existing links and avoid duplicated documenta -### 1.1 Core Implementation: `SupersetResultSet` - -The `SupersetResultSet` class (located at `superset/result_set.py`) is the primary Arrow integration point in Superset: - -```python -class SupersetResultSet: - def __init__( - self, - data: DbapiResult, - cursor_description: DbapiDescription, - db_engine_spec: type[BaseEngineSpec], - ): - # Converts database results to PyArrow Table - self.table = pa.Table.from_arrays(pa_data, names=column_names) -``` - -Key capabilities: - -- Converts raw DBAPI results to `pa.Table` with automatic type inference. -- Handles edge cases: nested types, temporal types with timezone, large integers. -- Provides `pa_table` property for direct Arrow table access. -- Converts to pandas via `to_pandas_df()` with nullable integer support. - -### 1.2 Type Mapping - -The system maps PyArrow types to Superset's generic types: - -```python -@staticmethod -def convert_pa_dtype(pa_dtype: pa.DataType) -> Optional[str]: - if pa.types.is_boolean(pa_dtype): return "BOOL" - if pa.types.is_integer(pa_dtype): return "INT" - if pa.types.is_floating(pa_dtype): return "FLOAT" - if pa.types.is_string(pa_dtype): return "STRING" - if pa.types.is_temporal(pa_dtype): return "DATETIME" - return None -``` - -### 1.3 Arrow IPC Support - -Arrow IPC (Inter-Process Communication) is used for efficient result serialization in SQL Lab: - -```python -# superset/sqllab/utils.py -def write_ipc_buffer(table: pa.Table) -> pa.Buffer: - sink = pa.BufferOutputStream() - with pa.ipc.new_stream(sink, table.schema) as writer: - writer.write_table(table) - return sink.getvalue() -``` - -This enables: - -- Zero-copy deserialization of cached query results. -- Efficient transfer of large datasets. -- Type-safe serialization with full schema preservation. 
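The report only shows the write path of the IPC round trip; the read path that cached SQL Lab results rely on is the natural counterpart. A minimal sketch of that read side — the `read_ipc_buffer` name is illustrative here, not an existing Superset helper:

```python
import pyarrow as pa

def read_ipc_buffer(buffer: pa.Buffer) -> pa.Table:
    """Deserialize an Arrow IPC stream produced by write_ipc_buffer().

    open_stream reads the schema and record batches back with minimal
    copying, which is what makes cached query results cheap to reload.
    """
    reader = pa.ipc.open_stream(buffer)
    return reader.read_all()
```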
- -### 1.4 CSV Upload with PyArrow Engine - -The file upload system can use PyArrow for CSV parsing: - -```python -# superset/commands/database/uploaders/csv_reader.py -def _get_csv_engine() -> Literal["c", "pyarrow"]: - # Uses pyarrow engine when available for faster parsing - pyarrow_spec = util.find_spec("pyarrow") - if pyarrow_spec: - return "pyarrow" - return "c" -``` - -## 2. Pandas DataFrame Processing Pipeline - -### 2.1 Post-Processing Functions - -Superset provides a rich set of DataFrame transformations in `superset/utils/pandas_postprocessing/`: - -| Function | Description | -| --- | --- | -| `aggregate` | Apply aggregation functions | -| `pivot` | Reshape data with pivot operations | -| `rolling` | Rolling window calculations | -| `resample` | Time-series resampling | -| `contribution` | Calculate percentage contributions | -| `compare` | Period-over-period comparisons | -| `boxplot` | Statistical boxplot calculations | -| `histogram` | Histogram bin calculations | -| `prophet` | Time-series forecasting (optional) | - -### 2.2 Query Result Processing - -The `QueryContextProcessor` manages DataFrame handling for chart rendering: - -```python -# superset/common/query_context_processor.py -class QueryContextProcessor: - cache_type: ClassVar[str] = "df" # Caches pandas DataFrames - - def get_df_payload(self, query_obj: QueryObject) -> dict[str, Any]: - # Returns cached or fresh DataFrame results -``` - -### 2.3 DataFrame Utilities - -Additional utilities in `superset/common/utils/dataframe_utils.py`: - -- `left_join_df`: DataFrame join operations. -- `full_outer_join_df`: Full outer joins. -- `df_metrics_to_num`: Metric type coercion. -- `is_datetime_series`: Type detection. - -## 3. Obstacles to DataFrame Ingestion - -### 3.1 Current Data Flow Architecture - -```text -┌─────────────────┐ ┌──────────────┐ ┌────────────────┐ -│ Database │ │ SupersetRS │ │ Pandas DF │ -│ (SQL Query) │───▶│ (PyArrow) │───▶│ (Processing) │ -└─────────────────┘ └──────────────┘ └────────────────┘ - │ - ▼ - ┌────────────────┐ - │ Chart/Dashboard│ - │ (JSON Response)│ - └────────────────┘ -``` - -Problem: There is no direct path from external DataFrames to the visualization layer without going through a database. - -### 3.2 Identified Obstacles - -**Obstacle 1: Database-Centric Design** - -The Explorable protocol (data source interface) assumes SQL-based querying: - -```python -# superset/explorables/base.py -class Explorable(Protocol): - def get_query_result(self, query_object: QueryObject) -> QueryResult: - """Execute a query and return results.""" - - def get_query_str(self, query_obj: QueryObjectDict) -> str: - """Get the query string without executing.""" -``` - -Impact: Cannot directly query DataFrame objects without SQL translation. - -**Obstacle 2: Dataset Registration Requirement** - -All data sources must be registered as "Datasets" with: - -- Database connection. -- Table/view reference. -- Column metadata in database. - -Impact: External DataFrames require upload to a database first. - -**Obstacle 3: Security Model Assumptions** - -Row-Level Security (RLS) and permissions are tied to database objects: - -```python -@property -def perm(self) -> str: - """Permission string for this explorable.""" - # Format: "[database].[schema].[table]" -``` - -Impact: Permission model doesn't accommodate in-memory data. 
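The proposal in section 5 works around this obstacle by scoping virtual datasets to the MCP session that created them rather than to a `[database].[schema].[table]` permission string. A minimal illustrative check, assuming the `VirtualDataset` fields sketched in Phase 1 (`owner_session`, `created_at`, `ttl`); this is a hypothetical sketch, not existing Superset code:

```python
from datetime import datetime

def can_access_virtual_dataset(dataset: "VirtualDataset", session_id: str) -> bool:
    """Hypothetical session-scoped access check for in-memory datasets."""
    if dataset.owner_session != session_id:
        # Only the MCP session that registered the DataFrame may read it.
        return False
    # Expired datasets are treated as inaccessible until cleanup removes them
    # (assumes created_at is stored as naive UTC).
    return datetime.utcnow() < dataset.created_at + dataset.ttl
```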
- -**Obstacle 4: Caching Layer Design** - -The cache layer assumes cacheable, reproducible SQL queries: - -```python -def query_cache_key(self, query_obj: QueryObject) -> str: - # Cache key based on SQL query hash -``` - -Impact: In-memory DataFrames don't have stable cache keys. - -### 3.3 Ergonomic Improvements Needed - -| Area | Current State | Needed Improvement | -| --- | --- | --- | -| Data Source Protocol | SQL-only | Add DataFrame adapter | -| Dataset Model | Database-bound | Support virtual datasets | -| Column Metadata | From database schema | Allow explicit schema definition | -| Query Execution | SQL engine required | Allow DataFrame operations | -| Cache Keys | SQL-based | Support content-addressable hashing | -| Permissions | Database-centric | Add dataset-level permissions | - -## 4. MCP Service Architecture Review - -### 4.1 Current MCP Capabilities - -The existing MCP service (`superset/mcp_service/`) provides: - -- Tools: `execute_sql`, `generate_chart`, `generate_dashboard`, `list_datasets`, etc. -- Resources: Instance metadata, schema discovery. -- Prompts: Guided workflows. -- Auth: JWT and development mode authentication. -- Caching: Redis-backed for multi-pod deployments. - -### 4.2 Current Dashboard Creation Flow - -```text -┌─────────────┐ ┌──────────────┐ ┌────────────────┐ -│ MCP Client │ │ execute_sql │ │ Database │ -│ (LLM Agent) │───▶│ Tool │───▶│ Query │ -└─────────────┘ └──────────────┘ └────────────────┘ - │ │ - │ ┌──────────────┐ │ - └─────────▶│generate_chart│◀────────────┘ - │ Tool │ (via dataset_id) - └──────────────┘ - │ - ▼ - ┌──────────────────┐ - │ generate_dashboard│ - │ Tool │ - └──────────────────┘ -``` - -### 4.3 Gap Analysis - -| Current Capability | Gap for DataFrame Ingestion | -| --- | --- | -| `execute_sql` returns DataFrame | No way to register as reusable dataset | -| `generate_chart` needs `dataset_id` | No dataset_id for ephemeral data | -| `list_datasets` shows DB tables | No ephemeral dataset discovery | -| Schema validation | No schema definition from DataFrame | - -## 5. Plan: Fast MCP DataFrame Interface for Dashboards - -### 5.1 Proposed Architecture - -```text -┌─────────────────────────────────────────────────────────────────┐ -│ MCP DataFrame Interface │ -├─────────────────────────────────────────────────────────────────┤ -│ │ -│ ┌──────────────┐ ┌────────────────┐ ┌────────────────┐ │ -│ │ DataFrame │ │ Virtual Dataset │ │ Chart Builder │ │ -│ │ Ingestion │───▶│ Registry │───▶│ (Temporary) │ │ -│ └──────────────┘ └────────────────┘ └────────────────┘ │ -│ │ │ │ │ -│ │ Arrow IPC │ Schema │ Config │ -│ ▼ ▼ ▼ │ -│ ┌──────────────┐ ┌────────────────┐ ┌────────────────┐ │ -│ │ Arrow Table │ │ Column Metadata│ │ Superset Chart │ │ -│ └──────────────┘ └────────────────┘ └────────────────┘ │ -│ │ │ -│ ▼ │ -│ ┌────────────────────────────────┐ │ -│ │ Dashboard Generator │ │ -│ │ (Ephemeral/Persistent) │ │ -│ └────────────────────────────────┘ │ -│ │ -└─────────────────────────────────────────────────────────────────┘ -``` - -### 5.2 Implementation Plan - -**Phase 1: Virtual Dataset Registry (Week 1-2)** - -Objective: Create an in-memory registry for DataFrame-based datasets. 
- -```python -# Current: superset/mcp_service/dataframe/registry.py -from dataclasses import dataclass -from typing import Dict, List -import pyarrow as pa -from datetime import datetime, timedelta - -@dataclass -class VirtualDataset: - """Represents a DataFrame-based virtual dataset.""" - id: str # UUID for the virtual dataset - name: str - schema: pa.Schema - table: pa.Table - created_at: datetime - ttl: timedelta # Time-to-live for cleanup - owner_session: str # MCP session that created it - -class VirtualDatasetRegistry: - """In-memory registry for virtual datasets.""" - - def register(self, name: str, table: pa.Table, ttl: timedelta) -> str: - """Register a DataFrame as a virtual dataset.""" - - def get(self, dataset_id: str) -> VirtualDataset | None: - """Retrieve a virtual dataset.""" - - def query(self, dataset_id: str, query_obj: QueryObject) -> pa.Table: - """Execute a query against a virtual dataset using DuckDB.""" - - def cleanup_expired(self) -> int: - """Remove expired datasets.""" -``` - -Key features: - -- Session-scoped datasets (cleaned up when session ends). -- Optional persistence for sharing across sessions. -- DuckDB integration for SQL queries against Arrow tables. - -**Phase 2: DataFrame Ingestion MCP Tool (Week 2-3)** - -Objective: Create MCP tools for DataFrame ingestion. - -```python -# Current: superset/mcp_service/dataframe/tool/ingest_dataframe.py -from pydantic import BaseModel, Field -from typing import List, Optional -import base64 - -class ColumnSchema(BaseModel): - """Column definition for DataFrame ingestion.""" - name: str - type: str # Arrow type string - is_temporal: bool = False - is_metric: bool = False - -class IngestDataFrameRequest(BaseModel): - """Request to ingest a DataFrame via MCP.""" - name: str = Field(description="Display name for the dataset") - data: str = Field(description="Base64-encoded Arrow IPC stream") - schema: Optional[List[ColumnSchema]] = Field( - default=None, - description="Optional explicit schema (inferred if not provided)") - ttl_minutes: int = Field( - default=60, - description="Time-to-live in minutes (0 for session lifetime)") - -@tool(tags=["dataframe", "mutate"]) -async def ingest_dataframe( - request: IngestDataFrameRequest, - ctx: Context -) -> IngestDataFrameResponse: - """ - Ingest a DataFrame from Arrow IPC format for visualization. - - This tool allows AI agents to upload DataFrame data directly - without requiring database storage. The data is registered as - a virtual dataset that can be used with generate_chart. - - Example usage: - ```python - import pyarrow as pa - import base64 - - # Create Arrow table - table = pa.table({'x': [1, 2, 3], 'y': [4, 5, 6]}) - - # Serialize to IPC - sink = pa.BufferOutputStream() - with pa.ipc.new_stream(sink, table.schema) as writer: - writer.write_table(table) - data = base64.b64encode(sink.getvalue().to_pybytes()).decode() - - # Ingest via MCP - result = await ingest_dataframe(IngestDataFrameRequest( - name="my_analysis", - data=data - )) - - # Use virtual_dataset_id with generate_chart - chart = await generate_chart(GenerateChartRequest( - dataset_id=f"virtual:{result.dataset_id}", - config={...} - )) - ``` - """ -``` - -**Phase 3: Virtual Dataset Query Adapter (Week 3-4)** - -Objective: Create an Explorable implementation for virtual datasets. 
- -```python -# Proposed: superset/mcp_service/dataframe/virtual_explorable.py -from superset.explorables.base import Explorable -import duckdb - -class VirtualDataFrameExplorable: - """Explorable implementation for in-memory DataFrames.""" - - def __init__(self, virtual_dataset: VirtualDataset): - self._dataset = virtual_dataset - self._duckdb = duckdb.connect() - # Register Arrow table with DuckDB for SQL queries - self._duckdb.register("data", virtual_dataset.table) - - def get_query_result(self, query_obj: QueryObject) -> QueryResult: - """Execute query using DuckDB against Arrow table.""" - # Translate QueryObject to SQL - sql = self._build_sql(query_obj) - # Execute against DuckDB - result = self._duckdb.execute(sql).arrow() - return QueryResult(df=result.to_pandas(), ...) - - @property - def columns(self) -> list[Any]: - """Return column metadata from Arrow schema.""" - return [ - { - "column_name": field.name, - "type": str(field.type), - "is_dttm": pa.types.is_temporal(field.type), - } - for field in self._dataset.schema - ] -``` - -**Phase 4: Chart Generation Integration (Week 4-5)** - -Objective: Extend `generate_chart` to support virtual datasets. - -```python -# Modifications to superset/mcp_service/chart/tool/generate_chart.py - -async def generate_chart(request: GenerateChartRequest, ctx: Context): - """Extended to support virtual datasets.""" - - dataset_id = request.dataset_id - - # Check if this is a virtual dataset reference - if isinstance(dataset_id, str) and dataset_id.startswith("virtual:"): - virtual_id = dataset_id[8:] # Remove "virtual:" prefix - - # Get virtual dataset from registry - from superset.mcp_service.dataframe.registry import get_registry - registry = get_registry() - virtual_dataset = registry.get(virtual_id) - - if not virtual_dataset: - return error_response("Virtual dataset not found or expired") - - # Create chart with virtual explorable - explorable = VirtualDataFrameExplorable(virtual_dataset) - # ... continue with chart generation using explorable - else: - # Existing database dataset logic - ... -``` - -**Phase 5: Fast Dashboard Pipeline (Week 5-6)** - -Objective: Create an end-to-end pipeline for DataFrame-to-Dashboard. - -```python -# Proposed: superset/mcp_service/dataframe/tool/create_dashboard_from_dataframe.py - -class CreateDashboardFromDataFrameRequest(BaseModel): - """Create a complete dashboard from DataFrame data.""" - name: str - data: str # Base64 Arrow IPC - charts: List[ChartSpec] - layout: Optional[LayoutSpec] = None - auto_suggest: bool = Field( - default=True, - description="Auto-suggest charts based on data analysis") - -@tool(tags=["dataframe", "dashboard", "mutate"]) -async def create_dashboard_from_dataframe( - request: CreateDashboardFromDataFrameRequest, - ctx: Context -) -> CreateDashboardFromDataFrameResponse: - """ - Create a complete dashboard directly from DataFrame data. - - This is a high-level tool that combines: - 1. DataFrame ingestion - 2. Chart generation - 3. Dashboard assembly - - Into a single operation for maximum efficiency. - - When auto_suggest=True, the tool analyzes the data and - suggests appropriate visualizations based on: - - Column types (temporal, categorical, numeric) - - Data cardinality - - Statistical properties - """ -``` - -### 5.3 DuckDB Integration for DataFrame Queries - -Why DuckDB? - -DuckDB provides crucial capabilities: - -- Zero-copy Arrow Integration: Native support for querying Arrow tables. -- SQL Compatibility: Familiar SQL syntax for complex queries. 
-- In-Process: No network overhead for queries. -- Rich Analytics: Window functions, aggregations, joins. - -Example DuckDB usage for virtual datasets: - -```python -import duckdb -import pyarrow as pa - -def query_arrow_table(table: pa.Table, sql: str) -> pa.Table: - """Execute SQL against Arrow table using DuckDB.""" - conn = duckdb.connect() - conn.register("df", table) - - # DuckDB returns Arrow directly - result = conn.execute(sql).arrow() - return result -``` - -### 5.4 Security Considerations - -| Concern | Mitigation | -| --- | --- | -| Resource exhaustion | TTL-based cleanup, size limits | -| Data isolation | Session-scoped datasets by default | -| Permission bypass | Virtual datasets inherit session permissions | -| Memory limits | Configurable max dataset size per session | -| Data leakage | Auto-cleanup on session end | - -### 5.5 Configuration Options - -```python -# Proposed config additions for superset_config.py - -# Virtual Dataset Configuration -MCP_VIRTUAL_DATASET_ENABLED = True -MCP_VIRTUAL_DATASET_MAX_SIZE_MB = 100 # Max size per dataset -MCP_VIRTUAL_DATASET_MAX_COUNT = 10 # Max datasets per session -MCP_VIRTUAL_DATASET_DEFAULT_TTL_MINUTES = 60 -MCP_VIRTUAL_DATASET_STORAGE_BACKEND = "memory" # or "redis" for multi-pod -``` - -## 6. Implementation Priority Matrix - -| Phase | Effort | Impact | Priority | -| --- | --- | --- | --- | -| 1. Virtual Dataset Registry | Medium | High | P0 | -| 2. DataFrame Ingestion Tool | Medium | High | P0 | -| 3. DuckDB Query Adapter | High | High | P1 | -| 4. Chart Generation Integration | Medium | High | P1 | -| 5. Fast Dashboard Pipeline | Medium | Medium | P2 | -| 6. Auto-Chart Suggestion | High | Medium | P3 | - -## 7. Success Metrics - -### Performance Targets - -- DataFrame ingestion: < 100ms for 1M rows. -- Chart generation from virtual dataset: < 500ms. -- Full dashboard creation: < 2s for 5 charts. - -### Adoption Metrics - -- Number of virtual datasets created per day. -- Conversion rate: virtual → persisted datasets. -- MCP tool usage frequency. - -## 8. Risks and Mitigations - -| Risk | Impact | Probability | Mitigation | -| --- | --- | --- | --- | -| Memory pressure from large datasets | High | Medium | Strict size limits, spill to disk | -| Session state complexity | Medium | High | Simple TTL-based cleanup | -| DuckDB compatibility issues | Medium | Low | Fallback to pandas processing | -| Security vulnerabilities | High | Low | Sandboxed execution, input validation | - -## 9. Future Considerations - -### DataFusion Integration - -Apache DataFusion could replace DuckDB for potential benefits: - -- Native Rust performance. -- Better Arrow integration. -- Async query execution. - -### Streaming DataFrame Support - -For very large datasets: - -- Chunked ingestion via Arrow `RecordBatch` streams. -- Progressive visualization rendering. -- Out-of-core processing. - -### Federated Queries - -Combining virtual datasets with database datasets: - -- Join in-memory data with database tables. -- Unified query optimization. - -## 10. Conclusion - -Apache Superset has a solid foundation for DataFrame/Arrow support at its core, but requires new components to fully realize DataFrame ingestion capabilities. The proposed MCP DataFrame interface provides a pragmatic path to enabling AI agents to create dashboards directly from DataFrame data without requiring database round-trips. - -The phased implementation approach allows for: - -- Quick wins with basic ingestion support. -- Progressive enhancement of query capabilities. 
-- Future-proof architecture for streaming and federation. - -Recommended next steps: - -- Implement Phase 1 (Virtual Dataset Registry) as a proof-of-concept. -- Validate DuckDB query performance with real-world data. -- Gather feedback from MCP service users on API design. -- Iterate on the implementation based on usage patterns. diff --git a/docs/developer_portal/contributing/development-setup.md b/docs/developer_portal/contributing/development-setup.md index a543eb244429..7bc989ef4b63 100644 --- a/docs/developer_portal/contributing/development-setup.md +++ b/docs/developer_portal/contributing/development-setup.md @@ -382,7 +382,7 @@ Ensure that you are using Python version 3.9, 3.10 or 3.11, then proceed with: ```bash # Create a virtual environment and activate it (recommended) -python3 -m venv venv # setup a python3 virtualenv +uv venv venv # setup a python3 virtualenv source venv/bin/activate # Install external dependencies @@ -414,7 +414,7 @@ Or you can install it via our Makefile ```bash # Create a virtual environment and activate it (recommended) -$ python3 -m venv venv # setup a python3 virtualenv +$ uv venv venv # setup a python3 virtualenv $ source venv/bin/activate # install pip packages + pre-commit @@ -438,9 +438,9 @@ If you have made changes to the FAB-managed templates, which are not built the s If you add a new requirement or update an existing requirement (per the `install_requires` section in `setup.py`) you must recompile (freeze) the Python dependencies to ensure that for CI, testing, etc. the build is deterministic. This can be achieved via, ```bash -python3 -m venv venv +uv venv venv source venv/bin/activate -python3 -m pip install -r requirements/development.txt +uv pip install -r requirements/development.txt ./scripts/uv-pip-compile.sh ``` @@ -912,7 +912,7 @@ root 10 6 7 14:09 ? 00:00:07 /usr/local/bin/python /usr/bin/f Inject debugpy into the running Flask process. In this case PID 6. ```bash -python3 -m debugpy --listen 0.0.0.0:5678 --pid 6 +uv run python -m debugpy --listen 0.0.0.0:5678 --pid 6 ``` Verify that debugpy is listening on port 5678 diff --git a/docs/developer_portal/contributing/howtos.md b/docs/developer_portal/contributing/howtos.md index e4469dfb07af..a563d63f11de 100644 --- a/docs/developer_portal/contributing/howtos.md +++ b/docs/developer_portal/contributing/howtos.md @@ -498,7 +498,7 @@ npm install # Recreate virtual environment deactivate rm -rf venv -python3 -m venv venv +uv venv venv source venv/bin/activate pip install -r requirements/development.txt pip install -e . 
diff --git a/docs/docs/contributing/development.mdx b/docs/docs/contributing/development.mdx index b75f2d5118a4..7c3d9a9df923 100644 --- a/docs/docs/contributing/development.mdx +++ b/docs/docs/contributing/development.mdx @@ -328,7 +328,7 @@ Ensure that you are using Python version 3.9, 3.10 or 3.11, then proceed with: ```bash # Create a virtual environment and activate it (recommended) -python3 -m venv venv # setup a python3 virtualenv +uv venv venv # setup a python3 virtualenv source venv/bin/activate # Install external dependencies @@ -366,7 +366,7 @@ Or you can install it via our Makefile ```bash # Create a virtual environment and activate it (recommended) -$ python3 -m venv venv # setup a python3 virtualenv +$ uv venv venv # setup a python3 virtualenv $ source venv/bin/activate # install pip packages + pre-commit @@ -390,9 +390,9 @@ If you have made changes to the FAB-managed templates, which are not built the s If you add a new requirement or update an existing requirement (per the `install_requires` section in `setup.py`) you must recompile (freeze) the Python dependencies to ensure that for CI, testing, etc. the build is deterministic. This can be achieved via, ```bash -python3 -m venv venv +uv venv venv source venv/bin/activate -python3 -m pip install -r requirements/development.txt +uv pip install -r requirements/development.txt ./scripts/uv-pip-compile.sh ``` @@ -955,7 +955,7 @@ root 10 6 7 14:09 ? 00:00:07 /usr/local/bin/python /usr/bin/f Inject debugpy into the running Flask process. In this case PID 6. ```bash -python3 -m debugpy --listen 0.0.0.0:5678 --pid 6 +uv run python -m debugpy --listen 0.0.0.0:5678 --pid 6 ``` Verify that debugpy is listening on port 5678 diff --git a/docs/docs/contributing/howtos.mdx b/docs/docs/contributing/howtos.mdx index 37eb61ad817e..be3f656f8478 100644 --- a/docs/docs/contributing/howtos.mdx +++ b/docs/docs/contributing/howtos.mdx @@ -442,7 +442,7 @@ root 10 6 7 14:09 ? 00:00:07 /usr/local/bin/python /usr/bin/f Inject debugpy into the running Flask process. In this case PID 6. ```bash -python3 -m debugpy --listen 0.0.0.0:5678 --pid 6 +uv run python -m debugpy --listen 0.0.0.0:5678 --pid 6 ``` Verify that debugpy is listening on port 5678 diff --git a/docs/docs/installation/pypi.mdx b/docs/docs/installation/pypi.mdx index 14228c208121..b45b0c669b75 100644 --- a/docs/docs/installation/pypi.mdx +++ b/docs/docs/installation/pypi.mdx @@ -108,7 +108,7 @@ You can create and activate a virtual environment using the following commands. ```bash # virtualenv is shipped in Python 3.6+ as venv instead of pyvenv. # See https://docs.python.org/3.6/library/venv.html -python3 -m venv venv +uv venv venv . 
venv/bin/activate ``` diff --git a/docs/package.json b/docs/package.json index e41f3cee0480..c42a97b6075e 100644 --- a/docs/package.json +++ b/docs/package.json @@ -10,7 +10,7 @@ "start:quick": "yarn run _init && NODE_OPTIONS='--max-old-space-size=8192' NODE_ENV=development docusaurus start", "stop": "pkill -f 'docusaurus start' || pkill -f 'docusaurus serve' || echo 'No docusaurus server running'", "build": "yarn run _init && yarn run generate:all && NODE_OPTIONS='--max-old-space-size=8192' DEBUG=docusaurus:* docusaurus build", - "generate:api-docs": "python3 scripts/fix-openapi-spec.py && docusaurus gen-api-docs superset && node scripts/convert-api-sidebar.mjs && node scripts/generate-api-index.mjs && node scripts/generate-api-tag-pages.mjs", + "generate:api-docs": "uv run python scripts/fix-openapi-spec.py && docusaurus gen-api-docs superset && node scripts/convert-api-sidebar.mjs && node scripts/generate-api-index.mjs && node scripts/generate-api-tag-pages.mjs", "clean:api-docs": "docusaurus clean-api-docs superset", "swizzle": "docusaurus swizzle", "deploy": "docusaurus deploy", @@ -24,8 +24,8 @@ "generate:database-docs": "node scripts/generate-database-docs.mjs", "gen-db-docs": "node scripts/generate-database-docs.mjs", "generate:all": "yarn run generate:extension-components & yarn run generate:superset-components & yarn run generate:database-docs & wait && yarn run generate:api-docs", - "lint:db-metadata": "python3 ../superset/db_engine_specs/lint_metadata.py", - "lint:db-metadata:report": "python3 ../superset/db_engine_specs/lint_metadata.py --markdown -o ../superset/db_engine_specs/METADATA_STATUS.md", + "lint:db-metadata": "uv run python ../superset/db_engine_specs/lint_metadata.py", + "lint:db-metadata:report": "uv run python ../superset/db_engine_specs/lint_metadata.py --markdown -o ../superset/db_engine_specs/METADATA_STATUS.md", "update:readme-db-logos": "node scripts/generate-database-docs.mjs --update-readme", "eslint": "eslint .", "version:add": "node scripts/manage-versions.mjs add", diff --git a/docs/scripts/generate-database-docs.mjs b/docs/scripts/generate-database-docs.mjs index 912569294ee1..c66bb927c8c8 100644 --- a/docs/scripts/generate-database-docs.mjs +++ b/docs/scripts/generate-database-docs.mjs @@ -339,7 +339,7 @@ print(json.dumps(debug_info), file=sys.stderr) print(json.dumps(databases, default=str)) `; - const result = spawnSync('python3', ['-c', pythonCode], { + const result = spawnSync('uv', ['run', 'python', '-c', pythonCode], { cwd: ROOT_DIR, encoding: 'utf-8', timeout: 30000, @@ -767,7 +767,7 @@ function extractCustomErrors() { try { const scriptPath = path.join(__dirname, 'extract_custom_errors.py'); - const result = spawnSync('python3', [scriptPath], { + const result = spawnSync('uv', ['run', 'python', scriptPath], { cwd: ROOT_DIR, encoding: 'utf-8', timeout: 30000, diff --git a/docs/versioned_docs/version-6.0.0/contributing/development.mdx b/docs/versioned_docs/version-6.0.0/contributing/development.mdx index 8e6822adc5b4..6d9c987d9ec0 100644 --- a/docs/versioned_docs/version-6.0.0/contributing/development.mdx +++ b/docs/versioned_docs/version-6.0.0/contributing/development.mdx @@ -328,7 +328,7 @@ Ensure that you are using Python version 3.9, 3.10 or 3.11, then proceed with: ```bash # Create a virtual environment and activate it (recommended) -python3 -m venv venv # setup a python3 virtualenv +uv venv venv # setup a python3 virtualenv source venv/bin/activate # Install external dependencies @@ -360,7 +360,7 @@ Or you can install it via our Makefile 
```bash # Create a virtual environment and activate it (recommended) -$ python3 -m venv venv # setup a python3 virtualenv +$ uv venv venv # setup a python3 virtualenv $ source venv/bin/activate # install pip packages + pre-commit @@ -384,9 +384,9 @@ If you have made changes to the FAB-managed templates, which are not built the s If you add a new requirement or update an existing requirement (per the `install_requires` section in `setup.py`) you must recompile (freeze) the Python dependencies to ensure that for CI, testing, etc. the build is deterministic. This can be achieved via, ```bash -python3 -m venv venv +uv venv venv source venv/bin/activate -python3 -m pip install -r requirements/development.txt +uv pip install -r requirements/development.txt ./scripts/uv-pip-compile.sh ``` @@ -838,7 +838,7 @@ root 10 6 7 14:09 ? 00:00:07 /usr/local/bin/python /usr/bin/f Inject debugpy into the running Flask process. In this case PID 6. ```bash -python3 -m debugpy --listen 0.0.0.0:5678 --pid 6 +uv run python -m debugpy --listen 0.0.0.0:5678 --pid 6 ``` Verify that debugpy is listening on port 5678 diff --git a/docs/versioned_docs/version-6.0.0/contributing/howtos.mdx b/docs/versioned_docs/version-6.0.0/contributing/howtos.mdx index b592c630e2da..58b4ae6ad133 100644 --- a/docs/versioned_docs/version-6.0.0/contributing/howtos.mdx +++ b/docs/versioned_docs/version-6.0.0/contributing/howtos.mdx @@ -356,7 +356,7 @@ root 10 6 7 14:09 ? 00:00:07 /usr/local/bin/python /usr/bin/f Inject debugpy into the running Flask process. In this case PID 6. ```bash -python3 -m debugpy --listen 0.0.0.0:5678 --pid 6 +uv run python -m debugpy --listen 0.0.0.0:5678 --pid 6 ``` Verify that debugpy is listening on port 5678 diff --git a/docs/versioned_docs/version-6.0.0/installation/pypi.mdx b/docs/versioned_docs/version-6.0.0/installation/pypi.mdx index 14228c208121..b45b0c669b75 100644 --- a/docs/versioned_docs/version-6.0.0/installation/pypi.mdx +++ b/docs/versioned_docs/version-6.0.0/installation/pypi.mdx @@ -108,7 +108,7 @@ You can create and activate a virtual environment using the following commands. ```bash # virtualenv is shipped in Python 3.6+ as venv instead of pyvenv. # See https://docs.python.org/3.6/library/venv.html -python3 -m venv venv +uv venv venv . venv/bin/activate ``` diff --git a/superset-frontend/eslint-rules/eslint-plugin-i18n-strings/index.js b/superset-frontend/eslint-rules/eslint-plugin-i18n-strings/index.js index 7cb97fea283f..1426f506d820 100644 --- a/superset-frontend/eslint-rules/eslint-plugin-i18n-strings/index.js +++ b/superset-frontend/eslint-rules/eslint-plugin-i18n-strings/index.js @@ -30,19 +30,26 @@ module.exports = { rules: { 'no-template-vars': { + meta: { + type: 'problem', + docs: { + description: 'Disallow variables in translation template strings', + }, + schema: [], + }, create(context) { function handler(node) { - if (node.arguments.length) { - const firstArgs = node.arguments[0]; + for (const arg of node.arguments ?? []) { if ( - firstArgs.type === 'TemplateLiteral' && - firstArgs.expressions.length + arg.type === 'TemplateLiteral' && + (arg.expressions?.length ?? 0) > 0 ) { context.report({ node, message: "Don't use variables in translation string templates. 
Flask-babel is a static translation service, so it can't handle strings that include variables", }); + break; } } } @@ -53,6 +60,13 @@ module.exports = { }, }, 'sentence-case-buttons': { + meta: { + type: 'suggestion', + docs: { + description: 'Enforce sentence case for button text in translations', + }, + schema: [], + }, create(context) { function isTitleCase(str) { // Match "Delete Dataset", "Create Chart", etc. (2+ title-cased words) @@ -60,12 +74,12 @@ module.exports = { } function isButtonContext(node) { - const { parent } = node; + const parent = node.parent; if (!parent) return false; // Check for button-specific props if (parent.type === 'Property') { - const key = parent.key.name; + const key = parent.key?.name; return [ 'primaryButtonName', 'secondaryButtonName', @@ -75,10 +89,10 @@ module.exports = { } // Check for Button components - if (parent.type === 'JSXExpressionContainer') { + if (String(parent.type) === 'JSXExpressionContainer') { const jsx = parent.parent; - if (jsx?.type === 'JSXElement') { - const elementName = jsx.openingElement.name.name; + if (String(jsx?.type) === 'JSXElement') { + const elementName = jsx?.openingElement?.name?.name; return elementName === 'Button'; } } @@ -87,23 +101,20 @@ module.exports = { } function handler(node) { - if (node.arguments.length) { - const firstArg = node.arguments[0]; - if ( - firstArg.type === 'Literal' && - typeof firstArg.value === 'string' - ) { - const text = firstArg.value; + for (const arg of node.arguments ?? []) { + if (arg.type !== 'Literal' || typeof arg.value !== 'string') { + continue; + } + const text = arg.value; - if (isButtonContext(node) && isTitleCase(text)) { - const sentenceCase = text - .toLowerCase() - .replace(/^\w/, c => c.toUpperCase()); - context.report({ - node: firstArg, - message: `Button text should use sentence case: "${text}" should be "${sentenceCase}"`, - }); - } + if (isButtonContext(node) && isTitleCase(text)) { + const sentenceCase = text + .toLowerCase() + .replace(/^\w/, c => c.toUpperCase()); + context.report({ + node: arg, + message: `Button text should use sentence case: "${text}" should be "${sentenceCase}"`, + }); } } } diff --git a/superset-frontend/eslint-rules/eslint-plugin-icons/index.js b/superset-frontend/eslint-rules/eslint-plugin-icons/index.js index a3dae1a3c104..763c073b288b 100644 --- a/superset-frontend/eslint-rules/eslint-plugin-icons/index.js +++ b/superset-frontend/eslint-rules/eslint-plugin-icons/index.js @@ -47,12 +47,20 @@ module.exports = { node.openingElement && node.openingElement.name.name === 'i' && node.openingElement.attributes && - node.openingElement.attributes.some( - attr => - attr.name && - attr.name.name === 'className' && - /fa fa-/.test(attr.value.value), - ) + node.openingElement.attributes.some(attr => { + if (attr.name?.name !== 'className') { + return false; + } + // className="fa fa-home" + if (attr.value?.type === 'Literal') { + return /fa fa-/.test(attr.value.value ?? ''); + } + // className={'fa fa-home'} + if (attr.value?.type === 'JSXExpressionContainer') { + return /fa fa-/.test(attr.value.expression?.value ?? 
''); + } + return false; + }) ) { context.report({ node, diff --git a/superset-frontend/eslint-rules/eslint-plugin-icons/no-fontawesome.test.js b/superset-frontend/eslint-rules/eslint-plugin-icons/no-fontawesome.test.js index 52a81eabe9ea..341cc3e04176 100644 --- a/superset-frontend/eslint-rules/eslint-plugin-icons/no-fontawesome.test.js +++ b/superset-frontend/eslint-rules/eslint-plugin-icons/no-fontawesome.test.js @@ -28,7 +28,9 @@ const plugin = require('.'); //------------------------------------------------------------------------------ // Tests //------------------------------------------------------------------------------ -const ruleTester = new RuleTester({ parserOptions: { ecmaVersion: 6 } }); +const ruleTester = new RuleTester({ + parserOptions: { ecmaVersion: 6, ecmaFeatures: { jsx: true } }, +}); const rule = plugin.rules['no-fa-icons-usage']; const errors = [ diff --git a/superset-frontend/eslint-rules/eslint-plugin-theme-colors/index.js b/superset-frontend/eslint-rules/eslint-plugin-theme-colors/index.js index ed1bd4d392e1..c23d5a3241fc 100644 --- a/superset-frontend/eslint-rules/eslint-plugin-theme-colors/index.js +++ b/superset-frontend/eslint-rules/eslint-plugin-theme-colors/index.js @@ -25,35 +25,26 @@ const COLOR_KEYWORDS = require('./colors'); function hasHexColor(quasi) { - if (typeof quasi === 'string') { - const regex = /#([a-f0-9]{3}|[a-f0-9]{4}(?:[a-f0-9]{2}){0,2})\b/gi; - return !!quasi.match(regex); - } - return false; + const regex = /#([a-f0-9]{3}|[a-f0-9]{4}(?:[a-f0-9]{2}){0,2})\b/gi; + return !!quasi.match(regex); } function hasRgbColor(quasi) { - if (typeof quasi === 'string') { - const regex = /rgba?\((\d+),\s*(\d+),\s*(\d+)(?:,\s*(\d+(?:\.\d+)?))?\)/i; - return !!quasi.match(regex); - } - return false; + const regex = /rgba?\((\d+),\s*(\d+),\s*(\d+)(?:,\s*(\d+(?:\.\d+)?))?\)/i; + return !!quasi.match(regex); } function hasLiteralColor(quasi, strict = false) { - if (typeof quasi === 'string') { - // matches literal colors at the start or end of a CSS prop - return COLOR_KEYWORDS.some(color => { - const regexColon = new RegExp(`: ${color}`); - const regexSemicolon = new RegExp(` ${color};`); - return ( - !!quasi.match(regexColon) || - !!quasi.match(regexSemicolon) || - (strict && quasi === color) - ); - }); - } - return false; + // matches literal colors at the start or end of a CSS prop + return COLOR_KEYWORDS.some(color => { + const regexColon = new RegExp(`: ${color}`); + const regexSemicolon = new RegExp(` ${color};`); + return ( + !!quasi.match(regexColon) || + !!quasi.match(regexSemicolon) || + (strict && quasi === color) + ); + }); } const WARNING_MESSAGE = @@ -67,6 +58,14 @@ const WARNING_MESSAGE = module.exports = { rules: { 'no-literal-colors': { + meta: { + type: 'suggestion', + docs: { + description: + 'Disallow literal color values; use theme colors instead', + }, + schema: [], + }, create(context) { const warned = []; return { @@ -80,7 +79,7 @@ module.exports = { node?.parent?.type === 'TemplateLiteral'; const loc = node?.parent?.parent?.loc; const locId = loc && JSON.stringify(loc); - const hasWarned = warned.includes(locId); + const hasWarned = locId ? 
warned.includes(locId) : false; if ( !hasWarned && (isChildParentTagged || @@ -90,26 +89,42 @@ module.exports = { hasHexColor(rawValue) || hasRgbColor(rawValue)) ) { - context.report(node, loc, WARNING_MESSAGE); - warned.push(locId); + context.report({ + node, + ...(loc && { loc }), + message: WARNING_MESSAGE, + }); + if (locId) { + warned.push(locId); + } } }, Literal(node) { const value = node?.value; - const isParentProperty = node?.parent?.type === 'Property'; - const locId = JSON.stringify(node.loc); - const hasWarned = warned.includes(locId); + if (typeof value !== 'string') { + return; + } + const parent = node?.parent; + const isPropertyValue = + parent?.type === 'Property' && parent.value === node; + const locId = node?.loc ? JSON.stringify(node.loc) : null; + const hasWarned = locId ? warned.includes(locId) : false; if ( !hasWarned && - isParentProperty && - value && + isPropertyValue && (hasLiteralColor(value, true) || hasHexColor(value) || hasRgbColor(value)) ) { - context.report(node, node.loc, WARNING_MESSAGE); - warned.push(locId); + context.report({ + node, + ...(node.loc && { loc: node.loc }), + message: WARNING_MESSAGE, + }); + if (locId) { + warned.push(locId); + } } }, }; diff --git a/superset-frontend/packages/superset-ui-core/src/chart/types/VizType.ts b/superset-frontend/packages/superset-ui-core/src/chart/types/VizType.ts index 0052f1609275..58eb389949e9 100644 --- a/superset-frontend/packages/superset-ui-core/src/chart/types/VizType.ts +++ b/superset-frontend/packages/superset-ui-core/src/chart/types/VizType.ts @@ -38,6 +38,7 @@ export enum VizType { Heatmap = 'heatmap_v2', Histogram = 'histogram_v2', Horizon = 'horizon', + Kroki = 'kroki_svg', LegacyBubble = 'bubble', Line = 'echarts_timeseries_line', MapBox = 'mapbox', diff --git a/superset-frontend/src/visualizations/Kroki/KrokiChart.tsx b/superset-frontend/src/visualizations/Kroki/KrokiChart.tsx new file mode 100644 index 000000000000..bc1052328968 --- /dev/null +++ b/superset-frontend/src/visualizations/Kroki/KrokiChart.tsx @@ -0,0 +1,138 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +import { t } from '@apache-superset/core'; +import { SupersetClient } from '@superset-ui/core'; +import { css, styled } from '@apache-superset/core/ui'; +import { useEffect, useState } from 'react'; +import { ensureAppRoot } from 'src/utils/pathUtils'; + +import { KrokiChartProps, KrokiRenderApiResponse } from './types'; + +const Root = styled.div<{ height: number; width: number }>` + ${({ height, width, theme }) => css` + height: ${height}px; + width: ${width}px; + overflow: auto; + border: 1px solid ${theme.colorBorder}; + border-radius: ${theme.borderRadius}px; + padding: ${theme.sizeUnit * 2}px; + background: ${theme.colorBgContainer}; + `} +`; + +const Message = styled.div` + ${({ theme }) => css` + color: ${theme.colorTextSecondary}; + `} +`; + +const ErrorMessage = styled.div` + ${({ theme }) => css` + color: ${theme.colorErrorText}; + `} +`; + +export default function KrokiChart({ + width, + height, + formData, +}: KrokiChartProps) { + const [svgMarkup, setSvgMarkup] = useState(''); + const [error, setError] = useState(''); + const [isLoading, setIsLoading] = useState(false); + + const diagramType = formData.diagram_type || 'mermaid'; + const diagramSource = formData.diagram_source || ''; + + useEffect(() => { + if (!diagramSource.trim()) { + setSvgMarkup(''); + setError(t('Diagram source is empty.')); + return; + } + + let isActive = true; + const controller = new AbortController(); + + const renderDiagram = async () => { + setIsLoading(true); + setError(''); + try { + const { json } = await SupersetClient.post({ + endpoint: ensureAppRoot('/api/v1/kroki/render/'), + jsonPayload: { + diagram_type: diagramType, + diagram_source: diagramSource, + output_format: 'svg', + }, + signal: controller.signal, + }); + + const apiResponse = json as KrokiRenderApiResponse; + const svg = apiResponse.result?.svg || ''; + + if (!svg.trim()) { + throw new Error(t('Kroki renderer returned an empty SVG response.')); + } + + if (isActive) { + setSvgMarkup(svg); + setError(''); + } + } catch (caughtError) { + if (controller.signal.aborted) { + return; + } + if (isActive) { + const fallback = t('Failed to render diagram.'); + setError( + caughtError instanceof Error && caughtError.message + ? caughtError.message + : fallback, + ); + setSvgMarkup(''); + } + } finally { + if (isActive) { + setIsLoading(false); + } + } + }; + + renderDiagram(); + + return () => { + isActive = false; + controller.abort(); + }; + }, [diagramSource, diagramType]); + + return ( + + {isLoading && {t('Rendering diagram...')}} + {!isLoading && error && {error}} + {!isLoading && !error && ( +
+ )} + + ); +} diff --git a/superset-frontend/src/visualizations/Kroki/images/thumbnail-dark.svg b/superset-frontend/src/visualizations/Kroki/images/thumbnail-dark.svg new file mode 100644 index 000000000000..83802b6861b2 --- /dev/null +++ b/superset-frontend/src/visualizations/Kroki/images/thumbnail-dark.svg @@ -0,0 +1,36 @@ + + + + + + + + + + + + Kroki SVG + + + + + + + diff --git a/superset-frontend/src/visualizations/Kroki/images/thumbnail.svg b/superset-frontend/src/visualizations/Kroki/images/thumbnail.svg new file mode 100644 index 000000000000..61f7cf4f03ae --- /dev/null +++ b/superset-frontend/src/visualizations/Kroki/images/thumbnail.svg @@ -0,0 +1,36 @@ + + + + + + + + + + + + Kroki SVG + + + + + + + diff --git a/superset-frontend/src/visualizations/Kroki/index.ts b/superset-frontend/src/visualizations/Kroki/index.ts new file mode 100644 index 000000000000..ae156108afc5 --- /dev/null +++ b/superset-frontend/src/visualizations/Kroki/index.ts @@ -0,0 +1,49 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +import { t } from '@apache-superset/core'; +import { ChartMetadata, ChartPlugin } from '@superset-ui/core'; + +import thumbnail from './images/thumbnail.svg'; +import thumbnailDark from './images/thumbnail-dark.svg'; +import buildQuery from './plugin/buildQuery'; +import controlPanel from './plugin/controlPanel'; +import transformProps from './plugin/transformProps'; + +const metadata = new ChartMetadata({ + category: t('Other'), + name: t('Kroki Diagram (SVG)'), + description: t( + 'Render diagrams through a Kroki sidecar and display SVG output.', + ), + tags: [t('Diagram'), t('SVG'), t('Markdown')], + thumbnail, + thumbnailDark, +}); + +export default class KrokiChartPlugin extends ChartPlugin { + constructor() { + super({ + metadata, + buildQuery, + controlPanel, + transformProps, + loadChart: () => import('./KrokiChart'), + }); + } +} diff --git a/superset-frontend/src/visualizations/Kroki/plugin/buildQuery.ts b/superset-frontend/src/visualizations/Kroki/plugin/buildQuery.ts new file mode 100644 index 000000000000..1fada1481626 --- /dev/null +++ b/superset-frontend/src/visualizations/Kroki/plugin/buildQuery.ts @@ -0,0 +1,24 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +import { buildQueryContext, QueryFormData } from '@superset-ui/core'; + +// Kroki rendering is handled by a sidecar endpoint, so no datasource query is required. +export default function buildQuery(formData: QueryFormData) { + return buildQueryContext(formData, () => []); +} diff --git a/superset-frontend/src/visualizations/Kroki/plugin/controlPanel.ts b/superset-frontend/src/visualizations/Kroki/plugin/controlPanel.ts new file mode 100644 index 000000000000..be53cf663679 --- /dev/null +++ b/superset-frontend/src/visualizations/Kroki/plugin/controlPanel.ts @@ -0,0 +1,81 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +import { + ControlPanelConfig, + formatSelectOptions, +} from '@superset-ui/chart-controls'; +import { t } from '@apache-superset/core'; +import { validateNonEmpty } from '@superset-ui/core'; + +const diagramTypes = formatSelectOptions([ + 'mermaid', + 'plantuml', + 'graphviz', + 'd2', + 'bpmn', + 'nomnoml', + 'svgbob', + 'vega', + 'vegalite', +]); + +const controlPanel: ControlPanelConfig = { + controlPanelSections: [ + { + label: t('Diagram'), + expanded: true, + controlSetRows: [ + [ + { + name: 'diagram_type', + config: { + type: 'SelectControl', + label: t('Diagram type'), + description: t('Kroki renderer used for SVG output.'), + clearable: false, + default: 'mermaid', + choices: diagramTypes, + renderTrigger: true, + }, + }, + ], + [ + { + name: 'diagram_source', + config: { + type: 'TextAreaControl', + language: 'markdown', + label: t('Diagram source'), + description: t( + 'Diagram source sent to the Kroki sidecar and rendered as SVG.', + ), + default: + 'graph TD\n Client[Client] --> Superset[Superset]\n Superset --> Kroki[Kroki Sidecar]', + rows: 16, + renderTrigger: true, + validators: [validateNonEmpty], + }, + }, + ], + ], + }, + ], +}; + +export default controlPanel; diff --git a/superset-frontend/src/visualizations/Kroki/plugin/transformProps.ts b/superset-frontend/src/visualizations/Kroki/plugin/transformProps.ts new file mode 100644 index 000000000000..0147ebbe6410 --- /dev/null +++ b/superset-frontend/src/visualizations/Kroki/plugin/transformProps.ts @@ -0,0 +1,33 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +import { ChartProps } from '@superset-ui/core'; + +import { KrokiChartProps, KrokiFormData } from '../types'; + +export default function transformProps( + chartProps: ChartProps, +): KrokiChartProps { + const { width, height, formData } = chartProps; + + return { + width, + height, + formData: formData as KrokiFormData, + }; +} diff --git a/superset-frontend/src/visualizations/Kroki/types.ts b/superset-frontend/src/visualizations/Kroki/types.ts new file mode 100644 index 000000000000..a3a6afa632df --- /dev/null +++ b/superset-frontend/src/visualizations/Kroki/types.ts @@ -0,0 +1,38 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +import { QueryFormData } from '@superset-ui/core'; + +export interface KrokiFormData extends QueryFormData { + diagram_type?: string; + diagram_source?: string; +} + +export interface KrokiChartProps { + width: number; + height: number; + formData: KrokiFormData; +} + +export interface KrokiRenderApiResponse { + result?: { + diagram_type: string; + output_format: string; + svg: string; + }; +} diff --git a/superset-frontend/src/visualizations/presets/MainPreset.ts b/superset-frontend/src/visualizations/presets/MainPreset.ts index c28dea349453..ae87b7c22908 100644 --- a/superset-frontend/src/visualizations/presets/MainPreset.ts +++ b/superset-frontend/src/visualizations/presets/MainPreset.ts @@ -88,6 +88,7 @@ import { HandlebarsChartPlugin } from '@superset-ui/plugin-chart-handlebars'; import { ChartCustomizationPlugins, FilterPlugins } from 'src/constants'; import AgGridTableChartPlugin from '@superset-ui/plugin-chart-ag-grid-table'; import TimeTableChartPlugin from '../TimeTable'; +import KrokiChartPlugin from '../Kroki'; export default class MainPreset extends Preset { constructor() { @@ -195,6 +196,7 @@ export default class MainPreset extends Preset { new EchartsTreeChartPlugin().configure({ key: VizType.Tree }), new EchartsSunburstChartPlugin().configure({ key: VizType.Sunburst }), new HandlebarsChartPlugin().configure({ key: VizType.Handlebars }), + new KrokiChartPlugin().configure({ key: VizType.Kroki }), new EchartsBubbleChartPlugin().configure({ key: VizType.Bubble }), new CartodiagramPlugin({ defaultLayers: [ diff --git a/superset/config.py b/superset/config.py index 9731d80bd587..2b11a2b14657 100644 --- a/superset/config.py +++ b/superset/config.py @@ -1945,6 +1945,38 @@ def EMAIL_HEADER_MUTATOR( # pylint: disable=invalid-name,unused-argument # noq # For workspaces with 10k+ channels, consider increasing to 10 SLACK_API_RATE_LIMIT_RETRY_COUNT = 2 +# Kroki sidecar integration for SVG diagram rendering. +KROKI_BASE_URL = "http://localhost:8000" +KROKI_REQUEST_TIMEOUT = 10 +KROKI_MAX_SOURCE_LENGTH = 200000 +KROKI_ALLOWED_DIAGRAM_TYPES = ( + "actdiag", + "blockdiag", + "bpmn", + "bytefield", + "c4plantuml", + "d2", + "dbml", + "ditaa", + "erd", + "excalidraw", + "graphviz", + "mermaid", + "nomnoml", + "nwdiag", + "packetdiag", + "pikchr", + "plantuml", + "rackdiag", + "seqdiag", + "structurizr", + "svgbob", + "umlet", + "vega", + "vegalite", + "wavedrom", +) + # The webdriver to use for generating reports. Use one of the following # firefox # Requires: geckodriver and firefox installations diff --git a/superset/mcp_service/README.md b/superset/mcp_service/README.md index d1b4eed9e7b2..dae1da307aa9 100644 --- a/superset/mcp_service/README.md +++ b/superset/mcp_service/README.md @@ -85,7 +85,7 @@ git clone https://github.com/apache/superset.git cd superset # 2. Set up Python environment (Python 3.10 or 3.11 required) -python3 -m venv venv +uv venv venv source venv/bin/activate # 3. 
Install dependencies diff --git a/superset/mcp_service/app.py b/superset/mcp_service/app.py index 9624964bc724..04482db1b817 100644 --- a/superset/mcp_service/app.py +++ b/superset/mcp_service/app.py @@ -58,12 +58,21 @@ def get_default_instructions(branding: str = "Apache Superset") -> str: - list_datasets: List datasets with advanced filters (1-based pagination) - get_dataset_info: Get detailed dataset information by ID (includes columns/metrics) +DataFrame Virtual Dataset Management: +- ingest_dataframe: Ingest Arrow IPC DataFrame payloads as virtual datasets +- list_source_capabilities: List dataframe source capabilities and pushdown support +- list_virtual_datasets: List accessible in-memory virtual datasets +- query_virtual_dataset: Run read-only SQL against a virtual dataset +- query_prometheus: Query Prometheus metrics and ingest as virtual dataset +- query_datafusion: Query Parquet/Arrow/virtual sources using DataFusion SQL +- remove_virtual_dataset: Delete a virtual dataset and release memory + Chart Management: - list_charts: List charts with advanced filters (1-based pagination) - get_chart_info: Get detailed chart information by ID - get_chart_preview: Get a visual preview of a chart with image URL - get_chart_data: Get underlying chart data in text-friendly format -- generate_chart: Create and save a new chart permanently +- generate_chart: Create chart previews; optionally save as a permanent chart - generate_explore_link: Create an interactive explore URL (preferred for exploration) - update_chart: Update existing saved chart configuration - update_chart_preview: Update cached chart preview without saving @@ -95,6 +104,18 @@ def get_default_instructions(branding: str = "Apache Superset") -> str: 3. generate_explore_link(dataset_id, config) -> preview interactively 4. generate_chart(dataset_id, config, save_chart=True) -> save permanently +To visualize a DataFrame without creating a database table: +1. ingest_dataframe(name, data, ttl_minutes) -> create virtual dataset +2. list_virtual_datasets -> confirm dataset and copy id +3. query_virtual_dataset(dataset_id, sql) -> inspect transformed results +4. generate_chart(dataset_id="virtual:...", save_chart=False) -> preview table/ascii/vega outputs + +To bring in external metric/file sources: +1. list_source_capabilities -> choose a source path based on connector capabilities +2. query_prometheus(base_url, promql, query_type) -> flatten metric series +3. query_datafusion(sql, sources) -> query parquet/arrow inputs +4. set ingest_as_virtual_dataset/ingest_result=true for cross-tool reuse + To explore data with SQL: 1. get_instance_info -> find database_id 2. 
execute_sql(database_id, sql) -> run query @@ -326,6 +347,15 @@ def create_mcp_app( get_dataset_info, list_datasets, ) +from superset.mcp_service.dataframe.tool import ( # noqa: F401, E402 + ingest_dataframe, + list_source_capabilities, + list_virtual_datasets, + query_datafusion, + query_prometheus, + query_virtual_dataset, + remove_virtual_dataset, +) from superset.mcp_service.explore.tool import ( # noqa: F401, E402 generate_explore_link, ) diff --git a/superset/mcp_service/chart/chart_utils.py b/superset/mcp_service/chart/chart_utils.py index 748e770a4d8f..01a6293b0709 100644 --- a/superset/mcp_service/chart/chart_utils.py +++ b/superset/mcp_service/chart/chart_utils.py @@ -32,7 +32,13 @@ TableChartConfig, XYChartConfig, ) +from superset.mcp_service.dataframe.identifiers import ( + extract_virtual_dataset_id, + is_virtual_dataset_identifier, +) +from superset.mcp_service.dataframe.registry import get_registry from superset.mcp_service.utils.url_utils import get_superset_base_url +from superset.utils.core import get_user_id from superset.utils import json logger = logging.getLogger(__name__) @@ -54,6 +60,11 @@ def generate_explore_link(dataset_id: int | str, form_data: Dict[str, Any]) -> s numeric_dataset_id = None dataset = None + if is_virtual_dataset_identifier(dataset_id): + # Virtual datasets do not map to persisted Superset datasource IDs, + # so they cannot produce an /explore form_data_key URL. + return "" + try: if isinstance(dataset_id, int) or ( isinstance(dataset_id, str) and dataset_id.isdigit() @@ -130,6 +141,27 @@ def is_column_truly_temporal(column_name: str, dataset_id: int | str | None) -> return True # Default to temporal if we can't check (backward compatible) try: + raw_virtual_dataset_id = extract_virtual_dataset_id(dataset_id) + if raw_virtual_dataset_id: + import pyarrow as pa + + try: + user_id = get_user_id() + except Exception: + user_id = None + registry = get_registry() + virtual_dataset = registry.get( + raw_virtual_dataset_id, + session_id=None, + user_id=user_id, + ) + if not virtual_dataset: + return True + for field in virtual_dataset.schema: + if field.name.lower() == column_name.lower(): + return pa.types.is_temporal(field.type) + return True + # Find dataset if isinstance(dataset_id, int) or ( isinstance(dataset_id, str) and dataset_id.isdigit() diff --git a/superset/mcp_service/chart/preview_utils.py b/superset/mcp_service/chart/preview_utils.py index 677d3034fd4a..e17622328bd7 100644 --- a/superset/mcp_service/chart/preview_utils.py +++ b/superset/mcp_service/chart/preview_utils.py @@ -113,18 +113,7 @@ def generate_preview_from_form_data( query_result = result["queries"][0] data = query_result.get("data", []) - # Generate preview based on format - if preview_format == "ascii": - return _generate_ascii_preview_from_data(data, form_data) - elif preview_format == "table": - return _generate_table_preview_from_data(data, form_data) - elif preview_format == "vega_lite": - return _generate_vega_lite_preview_from_data(data, form_data) - else: - return ChartError( - error=f"Unsupported preview format: {preview_format}", - error_type="UnsupportedFormat", - ) + return generate_preview_from_data(data, form_data, preview_format) except Exception as e: logger.error("Preview generation from form data failed: %s", e) @@ -133,6 +122,22 @@ def generate_preview_from_form_data( ) +def generate_preview_from_data( + data: List[Dict[str, Any]], form_data: Dict[str, Any], preview_format: str +) -> Any: + """Generate a preview payload from already queried chart data.""" + 
if preview_format == "ascii": + return _generate_ascii_preview_from_data(data, form_data) + if preview_format == "table": + return _generate_table_preview_from_data(data, form_data) + if preview_format == "vega_lite": + return _generate_vega_lite_preview_from_data(data, form_data) + return ChartError( + error=f"Unsupported preview format: {preview_format}", + error_type="UnsupportedFormat", + ) + + def _generate_ascii_preview_from_data( data: List[Dict[str, Any]], form_data: Dict[str, Any] ) -> ASCIIPreview: diff --git a/superset/mcp_service/chart/tool/generate_chart.py b/superset/mcp_service/chart/tool/generate_chart.py index ee555ca1d153..00e0d4296b13 100644 --- a/superset/mcp_service/chart/tool/generate_chart.py +++ b/superset/mcp_service/chart/tool/generate_chart.py @@ -38,6 +38,14 @@ GenerateChartResponse, PerformanceMetadata, ) +from superset.mcp_service.chart.virtual_dataset_bridge import ( + query_virtual_dataset_with_form_data, + resolve_virtual_dataset, +) +from superset.mcp_service.dataframe.identifiers import ( + is_virtual_dataset_identifier, +) +from superset.mcp_service.dataframe.tool.context import resolve_session_and_user from superset.mcp_service.utils.schema_utils import parse_request from superset.mcp_service.utils.url_utils import get_superset_base_url from superset.utils import json @@ -125,6 +133,9 @@ async def generate_chart( # noqa: C901 # Track runtime warnings to include in response runtime_warnings: list[str] = [] + session_id, user_id = resolve_session_and_user(ctx) + is_virtual_dataset = is_virtual_dataset_identifier(request.dataset_id) + virtual_dataset = None try: # Run comprehensive validation pipeline @@ -135,12 +146,15 @@ async def generate_chart( # noqa: C901 from superset.mcp_service.chart.validation import ValidationPipeline validation_result = ValidationPipeline.validate_request_with_warnings( - request.model_dump() + request.model_dump(), + session_id=session_id, + user_id=user_id, ) if validation_result.is_valid and validation_result.request is not None: # Use the validated request going forward request = validation_result.request + is_virtual_dataset = is_virtual_dataset_identifier(request.dataset_id) # Capture runtime warnings (informational, not blocking) if validation_result.warnings: @@ -184,8 +198,85 @@ async def generate_chart( # noqa: C901 explore_url = None form_data_key = None + if is_virtual_dataset: + virtual_dataset = resolve_virtual_dataset( + request.dataset_id, + session_id=session_id, + user_id=user_id, + ) + if virtual_dataset is None: + execution_time = int((time.time() - start_time) * 1000) + from superset.mcp_service.common.error_schemas import ( + ChartGenerationError, + ) + + error = ChartGenerationError( + error_type="dataset_not_found", + message=f"Virtual dataset not found: {request.dataset_id}", + details=( + f"No accessible virtual dataset found for identifier " + f"'{request.dataset_id}'." 
+ ), + suggestions=[ + "Use list_virtual_datasets to see accessible virtual datasets", + "Check dataset TTL and recreate it if expired", + "Use the prefixed format virtual:{uuid}", + ], + error_code="VIRTUAL_DATASET_NOT_FOUND", + ) + return GenerateChartResponse.model_validate( + { + "chart": None, + "error": error.model_dump(), + "performance": { + "query_duration_ms": execution_time, + "cache_status": "error", + "optimization_suggestions": [], + }, + "success": False, + "schema_version": "2.0", + "api_version": "v1", + } + ) + # Save chart by default (unless save_chart=False) if request.save_chart: + if is_virtual_dataset: + execution_time = int((time.time() - start_time) * 1000) + from superset.mcp_service.common.error_schemas import ( + ChartGenerationError, + ) + + error = ChartGenerationError( + error_type="unsupported_operation", + message=( + "Saving charts from virtual datasets is not supported" + ), + details=( + "Virtual datasets are in-memory and session-scoped. " + "Use save_chart=False to generate previews." + ), + suggestions=[ + "Set save_chart=False and generate_preview=True", + "Persist data to a Superset dataset before saving chart", + ], + error_code="VIRTUAL_DATASET_SAVE_UNSUPPORTED", + ) + return GenerateChartResponse.model_validate( + { + "chart": None, + "error": error.model_dump(), + "performance": { + "query_duration_ms": execution_time, + "cache_status": "error", + "optimization_suggestions": [], + }, + "success": False, + "schema_version": "2.0", + "api_version": "v1", + } + ) + await ctx.report_progress(2, 5, "Creating chart in database") from superset.commands.chart.create import CreateChartCommand @@ -337,19 +428,25 @@ async def generate_chart( # noqa: C901 # form_data_key remains None but chart is still valid else: await ctx.report_progress(2, 5, "Generating temporary chart preview") - # Generate explore link with cached form_data for preview-only mode - from superset.mcp_service.chart.chart_utils import generate_explore_link + if is_virtual_dataset: + runtime_warnings.append( + "Explore URLs are unavailable for virtual datasets. " + "Use preview payloads instead." 
+ ) + else: + # Generate explore link with cached form_data for preview-only mode + from superset.mcp_service.chart.chart_utils import generate_explore_link - explore_url = generate_explore_link(request.dataset_id, form_data) - await ctx.debug("Generated explore link: explore_url=%s" % (explore_url,)) + explore_url = generate_explore_link(request.dataset_id, form_data) + await ctx.debug("Generated explore link: explore_url=%s" % (explore_url,)) - # Extract form_data_key from the explore URL using proper URL parsing - if explore_url: - parsed = urlparse(explore_url) - query_params = parse_qs(parsed.query) - form_data_key_list = query_params.get("form_data_key", []) - if form_data_key_list: - form_data_key = form_data_key_list[0] + # Extract form_data_key from the explore URL using proper URL parsing + if explore_url: + parsed = urlparse(explore_url) + query_params = parse_qs(parsed.query) + form_data_key_list = query_params.get("form_data_key", []) + if form_data_key_list: + form_data_key = form_data_key_list[0] # Generate semantic analysis capabilities = analyze_chart_capabilities(chart, request.config) @@ -383,60 +480,101 @@ async def generate_chart( # noqa: C901 "Generating previews: formats=%s" % (str(request.preview_formats),) ) try: - for format_type in request.preview_formats: - await ctx.debug( - "Processing preview format: format=%s" % (format_type,) + if is_virtual_dataset and not chart_id: + if virtual_dataset is None: + raise RuntimeError( + "Virtual dataset missing during preview generation" + ) + + from superset.mcp_service.chart.preview_utils import ( + generate_preview_from_data, ) - if chart_id: - # For saved charts, use the existing preview generation - from superset.mcp_service.chart.tool.get_chart_preview import ( - _get_chart_preview_internal, - GetChartPreviewRequest, + if "url" in request.preview_formats: + runtime_warnings.append( + "URL previews are unavailable for virtual datasets." ) - preview_request = GetChartPreviewRequest( - identifier=str(chart_id), format=format_type + preview_formats = [ + format_type + for format_type in request.preview_formats + if format_type in {"ascii", "table", "vega_lite"} + ] + if not preview_formats: + preview_formats = ["table"] + runtime_warnings.append( + "URL previews are unavailable for virtual datasets; " + "returned a table preview instead." ) - preview_result = await _get_chart_preview_internal( - preview_request, ctx + + preview_rows, _ = query_virtual_dataset_with_form_data( + virtual_dataset, + form_data=form_data, + limit=1000, + ) + for format_type in preview_formats: + preview_result = generate_preview_from_data( + data=preview_rows, + form_data=form_data, + preview_format=format_type, ) + if not hasattr(preview_result, "error"): + previews[format_type] = preview_result + else: + for format_type in request.preview_formats: + await ctx.debug( + "Processing preview format: format=%s" % (format_type,) + ) + + if chart_id: + # For saved charts, use the existing preview generation + from superset.mcp_service.chart.tool.get_chart_preview import ( + _get_chart_preview_internal, + GetChartPreviewRequest, + ) - if hasattr(preview_result, "content"): - previews[format_type] = preview_result.content - else: - # For preview-only mode (save_chart=false) - # Note: Screenshot-based URL previews are not supported. - # Use the explore_url to view the chart interactively. 
- if format_type in ["ascii", "table", "vega_lite"]: - # Generate preview from form data without saved chart - from superset.mcp_service.chart.preview_utils import ( - generate_preview_from_form_data, + preview_request = GetChartPreviewRequest( + identifier=str(chart_id), format=format_type + ) + preview_result = await _get_chart_preview_internal( + preview_request, ctx ) - # Convert dataset_id to int only if it's numeric - if ( - isinstance(request.dataset_id, str) - and request.dataset_id.isdigit() - ): - dataset_id_for_preview = int(request.dataset_id) - elif isinstance(request.dataset_id, int): - dataset_id_for_preview = request.dataset_id - else: - # Skip preview generation for non-numeric dataset IDs - logger.warning( - "Cannot generate preview for non-numeric " + if hasattr(preview_result, "content"): + previews[format_type] = preview_result.content + else: + # For preview-only mode (save_chart=false) + # Note: Screenshot-based URL previews are not supported. + # Use the explore_url to view the chart interactively. + if format_type in ["ascii", "table", "vega_lite"]: + # Generate preview from form data without saved chart + from superset.mcp_service.chart.preview_utils import ( + generate_preview_from_form_data, ) - continue - preview_result = generate_preview_from_form_data( - form_data=form_data, - dataset_id=dataset_id_for_preview, - preview_format=format_type, - ) + # Convert dataset_id to int only if it's numeric + if ( + isinstance(request.dataset_id, str) + and request.dataset_id.isdigit() + ): + dataset_id_for_preview = int(request.dataset_id) + elif isinstance(request.dataset_id, int): + dataset_id_for_preview = request.dataset_id + else: + # Skip preview generation for non-numeric dataset IDs + logger.warning( + "Cannot generate preview for non-numeric " + ) + continue + + preview_result = generate_preview_from_form_data( + form_data=form_data, + dataset_id=dataset_id_for_preview, + preview_format=format_type, + ) - if not hasattr(preview_result, "error"): - previews[format_type] = preview_result + if not hasattr(preview_result, "error"): + previews[format_type] = preview_result except Exception as e: # Log warning but don't fail the entire request diff --git a/superset/mcp_service/chart/validation/dataset_validator.py b/superset/mcp_service/chart/validation/dataset_validator.py index b03d0ffe9c00..27357153277a 100644 --- a/superset/mcp_service/chart/validation/dataset_validator.py +++ b/superset/mcp_service/chart/validation/dataset_validator.py @@ -22,13 +22,17 @@ import difflib import logging -from typing import Dict, List, Tuple +from typing import Any, Dict, List, Tuple + +import pyarrow as pa from superset.mcp_service.chart.schemas import ( ColumnRef, TableChartConfig, XYChartConfig, ) +from superset.mcp_service.dataframe.identifiers import extract_virtual_dataset_id +from superset.mcp_service.dataframe.registry import get_registry from superset.mcp_service.common.error_schemas import ( ChartGenerationError, ColumnSuggestion, @@ -43,7 +47,10 @@ class DatasetValidator: @staticmethod def validate_against_dataset( - config: TableChartConfig | XYChartConfig, dataset_id: int | str + config: TableChartConfig | XYChartConfig, + dataset_id: int | str, + session_id: str | None = None, + user_id: int | None = None, ) -> Tuple[bool, ChartGenerationError | None]: """ Validate chart configuration against dataset schema. 
@@ -56,7 +63,11 @@ def validate_against_dataset( Tuple of (is_valid, error) """ # Get dataset context - dataset_context = DatasetValidator._get_dataset_context(dataset_id) + dataset_context = DatasetValidator._get_dataset_context( + dataset_id, + session_id=session_id, + user_id=user_id, + ) if not dataset_context: from superset.mcp_service.utils.error_builder import ( ChartErrorBuilder, @@ -98,9 +109,20 @@ def validate_against_dataset( return True, None @staticmethod - def _get_dataset_context(dataset_id: int | str) -> DatasetContext | None: + def _get_dataset_context( + dataset_id: int | str, + session_id: str | None = None, + user_id: int | None = None, + ) -> DatasetContext | None: """Get dataset context with column information.""" try: + if virtual_dataset_context := DatasetValidator._get_virtual_dataset_context( + dataset_id=dataset_id, + session_id=session_id, + user_id=user_id, + ): + return virtual_dataset_context + from superset.daos.dataset import DatasetDAO # Find dataset @@ -159,6 +181,52 @@ def _get_dataset_context(dataset_id: int | str) -> DatasetContext | None: logger.error("Error getting dataset context for %s: %s", dataset_id, e) return None + @staticmethod + def _get_virtual_dataset_context( + dataset_id: int | str, + session_id: str | None, + user_id: int | None, + ) -> DatasetContext | None: + """Build dataset context for prefixed virtual datasets.""" + raw_dataset_id = extract_virtual_dataset_id(dataset_id) + if raw_dataset_id is None: + return None + + registry = get_registry() + dataset = registry.get( + raw_dataset_id, + session_id=session_id, + user_id=user_id, + ) + if dataset is None: + return None + + available_columns: list[dict[str, Any]] = [] + for field in dataset.schema: + is_temporal = pa.types.is_temporal(field.type) + is_numeric = ( + pa.types.is_integer(field.type) + or pa.types.is_floating(field.type) + or pa.types.is_decimal(field.type) + ) + available_columns.append( + { + "name": field.name, + "type": str(field.type), + "is_temporal": is_temporal, + "is_numeric": is_numeric, + } + ) + + return DatasetContext( + id=dataset.id, + table_name=dataset.name, + schema="virtual", + database_name="virtual_arrow", + available_columns=available_columns, + available_metrics=[], + ) + @staticmethod def _extract_column_references( config: TableChartConfig | XYChartConfig, diff --git a/superset/mcp_service/chart/validation/pipeline.py b/superset/mcp_service/chart/validation/pipeline.py index 948f9d2e62df..778cc7fedfb2 100644 --- a/superset/mcp_service/chart/validation/pipeline.py +++ b/superset/mcp_service/chart/validation/pipeline.py @@ -128,6 +128,8 @@ class ValidationPipeline: @staticmethod def validate_request( request_data: Dict[str, Any], + session_id: str | None = None, + user_id: int | None = None, ) -> Tuple[bool, GenerateChartRequest | None, ChartGenerationError | None]: """ Validate a chart generation request through all validation layers. @@ -140,12 +142,18 @@ def validate_request( Note: Use validate_request_with_warnings() to also get runtime warnings. """ - result = ValidationPipeline.validate_request_with_warnings(request_data) + result = ValidationPipeline.validate_request_with_warnings( + request_data, + session_id=session_id, + user_id=user_id, + ) return result.is_valid, result.request, result.error @staticmethod def validate_request_with_warnings( request_data: Dict[str, Any], + session_id: str | None = None, + user_id: int | None = None, ) -> ValidationResult: """ Validate a chart generation request and return warnings as metadata. 
@@ -170,7 +178,10 @@ def validate_request_with_warnings( # Layer 2: Dataset validation is_valid, error = ValidationPipeline._validate_dataset( - request.config, request.dataset_id + request.config, + request.dataset_id, + session_id=session_id, + user_id=user_id, ) if not is_valid: return ValidationResult(is_valid=False, request=request, error=error) @@ -203,13 +214,21 @@ def validate_request_with_warnings( @staticmethod def _validate_dataset( - config: ChartConfig, dataset_id: int | str + config: ChartConfig, + dataset_id: int | str, + session_id: str | None = None, + user_id: int | None = None, ) -> Tuple[bool, ChartGenerationError | None]: """Validate configuration against dataset schema.""" try: from .dataset_validator import DatasetValidator - return DatasetValidator.validate_against_dataset(config, dataset_id) + return DatasetValidator.validate_against_dataset( + config, + dataset_id, + session_id=session_id, + user_id=user_id, + ) except ImportError: # Skip if dataset validator not available logger.warning( diff --git a/superset/mcp_service/chart/virtual_dataset_bridge.py b/superset/mcp_service/chart/virtual_dataset_bridge.py new file mode 100644 index 000000000000..fc5521c3cbca --- /dev/null +++ b/superset/mcp_service/chart/virtual_dataset_bridge.py @@ -0,0 +1,235 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +Bridge helpers for using MCP virtual datasets in chart workflows. 
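+
+These helpers translate Superset form_data (adhoc filters, dimensions, and
+metrics) into DuckDB SQL that runs against the dataset's Arrow table, which is
+registered under the table name "data".
+
+Example (sketch; column names are illustrative):
+
+    rows, columns = query_virtual_dataset_with_form_data(
+        dataset,
+        form_data={
+            "groupby": ["region"],
+            "metrics": [
+                {
+                    "aggregate": "SUM",
+                    "column": {"column_name": "sales"},
+                    "label": "total_sales",
+                }
+            ],
+        },
+        limit=100,
+    )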
+""" + +from __future__ import annotations + +from typing import Any + +from superset.mcp_service.dataframe.identifiers import ( + extract_virtual_dataset_id, +) +from superset.mcp_service.dataframe.registry import ( + VirtualDataset, + get_registry, +) +from superset.mcp_service.dataframe.tool.common import table_to_columns, table_to_rows + +SUPPORTED_FILTER_OPERATORS = {"=", "==", ">", "<", ">=", "<=", "!="} + + +def resolve_virtual_dataset( + dataset_id: int | str, + session_id: str | None, + user_id: int | None, +) -> VirtualDataset | None: + """Resolve a prefixed virtual dataset identifier to a registry entry.""" + raw_dataset_id = extract_virtual_dataset_id(dataset_id) + if raw_dataset_id is None: + return None + registry = get_registry() + return registry.get(raw_dataset_id, session_id=session_id, user_id=user_id) + + +def _quote_identifier(identifier: str) -> str: + escaped = identifier.replace('"', '""') + return f'"{escaped}"' + + +def _to_sql_literal(value: Any) -> str: + if value is None: + return "NULL" + if isinstance(value, bool): + return "TRUE" if value else "FALSE" + if isinstance(value, (int, float)): + return str(value) + escaped = str(value).replace("'", "''") + return f"'{escaped}'" + + +def _build_where_clause(form_data: dict[str, Any]) -> str: + filters = form_data.get("adhoc_filters") or [] + clauses: list[str] = [] + for filter_config in filters: + if not isinstance(filter_config, dict): + continue + if filter_config.get("expressionType") != "SIMPLE": + continue + column = filter_config.get("subject") + operator = filter_config.get("operator") + comparator = filter_config.get("comparator") + if not isinstance(column, str): + continue + if not isinstance(operator, str): + continue + if operator not in SUPPORTED_FILTER_OPERATORS: + continue + + normalized_operator = "=" if operator == "==" else operator + if comparator is None: + if normalized_operator == "=": + clauses.append(f"{_quote_identifier(column)} IS NULL") + elif normalized_operator == "!=": + clauses.append(f"{_quote_identifier(column)} IS NOT NULL") + continue + + clauses.append( + f"{_quote_identifier(column)} {normalized_operator} " + f"{_to_sql_literal(comparator)}" + ) + + if not clauses: + return "" + return " WHERE " + " AND ".join(clauses) + + +def _unique_dimension_columns(form_data: dict[str, Any]) -> list[str]: + columns: list[str] = [] + + query_mode = form_data.get("query_mode") + if query_mode == "raw": + raw_columns = form_data.get("all_columns") or form_data.get("columns") or [] + for value in raw_columns: + if isinstance(value, str) and value not in columns: + columns.append(value) + return columns + + x_axis = form_data.get("x_axis") + if isinstance(x_axis, str) and x_axis not in columns: + columns.append(x_axis) + elif isinstance(x_axis, dict): + x_name = x_axis.get("column_name") + if isinstance(x_name, str) and x_name not in columns: + columns.append(x_name) + + for group_column in form_data.get("groupby", []) or []: + if isinstance(group_column, str) and group_column not in columns: + columns.append(group_column) + + return columns + + +def _metric_sql_expressions(form_data: dict[str, Any]) -> list[str]: + metrics = form_data.get("metrics") or [] + expressions: list[str] = [] + + for metric in metrics: + if not isinstance(metric, dict): + continue + aggregate = str(metric.get("aggregate") or "SUM").upper() + if aggregate not in { + "SUM", + "COUNT", + "AVG", + "MIN", + "MAX", + "COUNT_DISTINCT", + "STDDEV", + "VAR", + "MEDIAN", + "PERCENTILE", + }: + continue + + column_name: str | 
None = None + metric_column = metric.get("column") + if isinstance(metric_column, dict): + raw_column_name = metric_column.get("column_name") + if isinstance(raw_column_name, str): + column_name = raw_column_name + + if aggregate == "COUNT" and column_name is None: + base_expression = "COUNT(*)" + default_label = "COUNT(*)" + else: + if column_name is None: + continue + quoted_column = _quote_identifier(column_name) + if aggregate == "COUNT_DISTINCT": + base_expression = f"COUNT(DISTINCT {quoted_column})" + elif aggregate == "PERCENTILE": + base_expression = f"QUANTILE_CONT({quoted_column}, 0.5)" + else: + base_expression = f"{aggregate}({quoted_column})" + default_label = f"{aggregate}({column_name})" + + label = metric.get("label") + alias = label if isinstance(label, str) and label else default_label + expressions.append(f"{base_expression} AS {_quote_identifier(alias)}") + + return expressions + + +def build_virtual_dataset_query(form_data: dict[str, Any], limit: int = 1000) -> str: + """Build a DuckDB SQL query from Superset form_data for virtual datasets.""" + safe_limit = max(1, min(limit, 10000)) + where_clause = _build_where_clause(form_data) + query_mode = form_data.get("query_mode") + + if query_mode == "raw": + columns = _unique_dimension_columns(form_data) + select_clause = ( + ", ".join(_quote_identifier(column) for column in columns) if columns else "*" + ) + return f"SELECT {select_clause} FROM data{where_clause} LIMIT {safe_limit}" + + dimensions = _unique_dimension_columns(form_data) + metrics = _metric_sql_expressions(form_data) + select_parts = [ + *(_quote_identifier(column) for column in dimensions), + *metrics, + ] + if not select_parts: + select_parts = ["*"] + + sql = f"SELECT {', '.join(select_parts)} FROM data{where_clause}" + if dimensions and metrics: + sql += " GROUP BY " + ", ".join(_quote_identifier(column) for column in dimensions) + sql += f" LIMIT {safe_limit}" + return sql + + +def query_virtual_dataset_with_form_data( + dataset: VirtualDataset, + form_data: dict[str, Any], + limit: int = 1000, +) -> tuple[list[dict[str, str | int | float | bool | None]], list[dict[str, str]]]: + """ + Execute a chart-style query for a virtual dataset and return rows/columns. + """ + try: + import duckdb + except ImportError as ex: + raise RuntimeError( + "DuckDB is required to render chart previews for virtual datasets. 
" + "Install with: pip install duckdb" + ) from ex + + sql = build_virtual_dataset_query(form_data, limit=limit) + connection = duckdb.connect() + connection.register("data", dataset.table) + try: + result = connection.execute(sql) + result_table = result.arrow() + finally: + connection.close() + + rows = table_to_rows(result_table) + columns = table_to_columns(result_table) + return rows, columns diff --git a/superset/mcp_service/common/error_schemas.py b/superset/mcp_service/common/error_schemas.py index ec0274cc0be8..21bcfcad4269 100644 --- a/superset/mcp_service/common/error_schemas.py +++ b/superset/mcp_service/common/error_schemas.py @@ -50,7 +50,7 @@ class DatasetContext(BaseModel): model_config = {"populate_by_name": True} - id: int = Field(..., description="Dataset ID") + id: int | str = Field(..., description="Dataset ID") table_name: str = Field(..., description="Table name") schema_name: str | None = Field( None, diff --git a/superset/mcp_service/dataframe/__init__.py b/superset/mcp_service/dataframe/__init__.py index 0fa5aaa0a5f5..1b1177bbb016 100644 --- a/superset/mcp_service/dataframe/__init__.py +++ b/superset/mcp_service/dataframe/__init__.py @@ -30,8 +30,16 @@ ) from superset.mcp_service.dataframe.schemas import ( ColumnSchema, + DataFrameSourceCapability, + DataFusionQueryRequest, + DataFusionQueryResponse, + DataFusionSourceConfig, IngestDataFrameRequest, IngestDataFrameResponse, + ListSourceCapabilitiesRequest, + ListSourceCapabilitiesResponse, + PrometheusQueryRequest, + PrometheusQueryResponse, VirtualDatasetInfo, ) @@ -42,5 +50,13 @@ "ColumnSchema", "IngestDataFrameRequest", "IngestDataFrameResponse", + "ListSourceCapabilitiesRequest", + "ListSourceCapabilitiesResponse", + "PrometheusQueryRequest", + "PrometheusQueryResponse", "VirtualDatasetInfo", + "DataFrameSourceCapability", + "DataFusionSourceConfig", + "DataFusionQueryRequest", + "DataFusionQueryResponse", ] diff --git a/superset/mcp_service/dataframe/identifiers.py b/superset/mcp_service/dataframe/identifiers.py new file mode 100644 index 000000000000..e101389c3483 --- /dev/null +++ b/superset/mcp_service/dataframe/identifiers.py @@ -0,0 +1,53 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +Helpers for handling MCP virtual dataset identifiers. 
+""" + +from __future__ import annotations + +VIRTUAL_DATASET_PREFIX = "virtual:" + + +def is_virtual_dataset_identifier(dataset_id: int | str) -> bool: + """Return True when dataset_id uses the virtual:{uuid} prefix.""" + return isinstance(dataset_id, str) and dataset_id.startswith( + VIRTUAL_DATASET_PREFIX + ) + + +def normalize_virtual_dataset_id(dataset_id: str) -> str: + """Normalize dataset IDs by removing the optional virtual: prefix.""" + if dataset_id.startswith(VIRTUAL_DATASET_PREFIX): + return dataset_id[len(VIRTUAL_DATASET_PREFIX) :] + return dataset_id + + +def extract_virtual_dataset_id(dataset_id: int | str) -> str | None: + """Extract raw virtual dataset UUID when dataset_id is prefixed.""" + if not is_virtual_dataset_identifier(dataset_id): + return None + if not isinstance(dataset_id, str): + return None + normalized = normalize_virtual_dataset_id(dataset_id).strip() + return normalized or None + + +def to_virtual_dataset_id(dataset_id: str) -> str: + """Create a virtual:{uuid} identifier from a raw virtual dataset UUID.""" + normalized = normalize_virtual_dataset_id(dataset_id) + return f"{VIRTUAL_DATASET_PREFIX}{normalized}" diff --git a/superset/mcp_service/dataframe/schemas.py b/superset/mcp_service/dataframe/schemas.py index e6c49b8514f9..fb641591b29a 100644 --- a/superset/mcp_service/dataframe/schemas.py +++ b/superset/mcp_service/dataframe/schemas.py @@ -27,7 +27,7 @@ from datetime import datetime from typing import Literal -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, model_validator class ColumnSchema(BaseModel): @@ -142,7 +142,8 @@ class IngestDataFrameResponse(BaseModel): virtual_dataset_id: str | None = Field( default=None, description=( - "Prefixed dataset ID in 'virtual:{uuid}' format for use with generate_chart" + "Prefixed dataset ID in 'virtual:{uuid}' format for virtual " + "dataset integrations" ), ) usage_hint: str | None = Field( @@ -152,6 +153,21 @@ class IngestDataFrameResponse(BaseModel): error: str | None = Field(default=None, description="Error message if failed") error_code: str | None = Field(default=None, description="Error code if failed") + @model_validator(mode="after") + def derive_virtual_dataset_id(self) -> "IngestDataFrameResponse": + """ + Populate virtual_dataset_id from dataset_id when omitted. + + This keeps response construction ergonomic while preserving explicit + override behavior when callers set virtual_dataset_id directly. + """ + if self.virtual_dataset_id is None and self.dataset_id: + if self.dataset_id.startswith("virtual:"): + self.virtual_dataset_id = self.dataset_id + else: + self.virtual_dataset_id = f"virtual:{self.dataset_id}" + return self + class ListVirtualDatasetsResponse(BaseModel): """Response from listing virtual datasets.""" @@ -166,7 +182,9 @@ class ListVirtualDatasetsResponse(BaseModel): class RemoveVirtualDatasetRequest(BaseModel): """Request to remove a virtual dataset.""" - dataset_id: str = Field(description="The virtual dataset ID to remove") + dataset_id: str = Field( + description="The virtual dataset ID to remove (raw UUID or virtual:{uuid})" + ) class RemoveVirtualDatasetResponse(BaseModel): @@ -183,7 +201,9 @@ class QueryVirtualDatasetRequest(BaseModel): Allows executing SQL queries against virtual datasets using DuckDB. """ - dataset_id: str = Field(description="The virtual dataset ID to query") + dataset_id: str = Field( + description="The virtual dataset ID to query (raw UUID or virtual:{uuid})" + ) sql: str = Field( description=( "SQL query to execute. 
The virtual dataset is available " @@ -213,6 +233,268 @@ class QueryVirtualDatasetResponse(BaseModel): error: str | None = Field(default=None, description="Error message if failed") +class PrometheusQueryRequest(BaseModel): + """Request to query Prometheus and optionally register a virtual dataset.""" + + base_url: str = Field( + description="Prometheus base URL, e.g. http://prometheus:9090" + ) + promql: str = Field(description="PromQL query expression", min_length=1) + query_type: Literal["range", "instant"] = Field( + default="range", + description="Prometheus API mode: range or instant query", + ) + start_time: datetime | None = Field( + default=None, + description="Range query start time (UTC if naive); defaults to 1 hour ago", + ) + end_time: datetime | None = Field( + default=None, + description="Range query end time (UTC if naive); defaults to now", + ) + step_seconds: int = Field( + default=60, + ge=1, + le=86400, + description="Range query step size in seconds", + ) + timeout_seconds: int = Field( + default=30, + ge=1, + le=300, + description="HTTP timeout for Prometheus request", + ) + verify_ssl: bool = Field(default=True, description="Verify TLS certificates") + ingest_as_virtual_dataset: bool = Field( + default=True, + description="Register query results as an MCP virtual dataset", + ) + dataset_name: str | None = Field( + default=None, + min_length=1, + max_length=100, + description="Optional name override for ingested virtual dataset", + ) + ttl_minutes: int = Field( + default=60, + ge=0, + le=1440, + description="TTL for ingested virtual dataset when ingestion is enabled", + ) + allow_cross_session: bool = Field( + default=False, + description="Allow same-user cross-session access to ingested dataset", + ) + + @model_validator(mode="after") + def validate_temporal_window(self) -> "PrometheusQueryRequest": + """Validate time-range constraints for range queries.""" + if self.query_type == "range": + if self.start_time and self.end_time and self.start_time > self.end_time: + raise ValueError("start_time must be less than or equal to end_time") + return self + + +class PrometheusQueryResponse(BaseModel): + """Response from querying Prometheus.""" + + success: bool = Field(description="Whether query execution succeeded") + result_type: str | None = Field( + default=None, + description="Prometheus result type, e.g. 
matrix or vector", + ) + rows: list[dict[str, str | int | float | bool | None]] | None = Field( + default=None, + description="Flattened Prometheus result rows", + ) + columns: list[dict[str, str]] | None = Field( + default=None, + description="Column metadata for flattened rows", + ) + row_count: int | None = Field(default=None, description="Number of rows returned") + dataset: VirtualDatasetInfo | None = Field( + default=None, + description="Virtual dataset info when ingestion is enabled", + ) + dataset_id: str | None = Field( + default=None, + description="Raw virtual dataset UUID when ingestion is enabled", + ) + virtual_dataset_id: str | None = Field( + default=None, + description="Prefixed virtual dataset ID when ingestion is enabled", + ) + source_capabilities: list["DataFrameSourceCapability"] = Field( + default_factory=list, + description="Capabilities for source adapters used during execution", + ) + warning: str | None = Field( + default=None, + description="Non-fatal warning message", + ) + error: str | None = Field(default=None, description="Error message if failed") + error_code: str | None = Field(default=None, description="Error code if failed") + + +class DataFusionSourceConfig(BaseModel): + """Source registration entry for DataFusion query execution.""" + + name: str = Field( + description="SQL table name to register in DataFusion", + min_length=1, + ) + source_type: Literal["parquet", "arrow_ipc", "virtual_dataset"] = Field( + description="Data source type to register" + ) + path: str | None = Field( + default=None, + description="Filesystem path or URI for parquet source_type", + ) + data: str | None = Field( + default=None, + description="Base64 Arrow IPC payload for arrow_ipc source_type", + ) + dataset_id: str | None = Field( + default=None, + description="Virtual dataset ID for virtual_dataset source_type", + ) + + @model_validator(mode="after") + def validate_required_source_fields(self) -> "DataFusionSourceConfig": + """Validate per-source required fields by source_type.""" + if self.source_type == "parquet" and not self.path: + raise ValueError("path is required when source_type='parquet'") + if self.source_type == "arrow_ipc" and not self.data: + raise ValueError("data is required when source_type='arrow_ipc'") + if self.source_type == "virtual_dataset" and not self.dataset_id: + raise ValueError( + "dataset_id is required when source_type='virtual_dataset'" + ) + return self + + +class DataFrameSourceCapability(BaseModel): + """Capability metadata for dataframe source adapters.""" + + source_type: str = Field(description="Source type identifier") + adapter_name: str = Field(description="Adapter class name") + supports_streaming: bool = Field( + description="Whether source supports streaming ingestion/read patterns" + ) + supports_projection_pushdown: bool = Field( + description="Whether source supports column projection pushdown" + ) + supports_predicate_pushdown: bool = Field( + description="Whether source supports predicate/filter pushdown" + ) + supports_sql_pushdown: bool = Field( + description="Whether source can execute SQL in source-native runtime" + ) + supports_virtual_dataset_ingestion: bool = Field( + description="Whether source output can be ingested as MCP virtual datasets" + ) + + +class ListSourceCapabilitiesRequest(BaseModel): + """Request to list dataframe source adapter capabilities.""" + + source_types: list[str] | None = Field( + default=None, + description=( + "Optional source types to include. Defaults to all known source types." 
+ ), + ) + include_prometheus: bool = Field( + default=True, + description="Include Prometheus HTTP source capability metadata", + ) + + +class ListSourceCapabilitiesResponse(BaseModel): + """Response listing dataframe source capability metadata.""" + + success: bool = Field(description="Whether capability discovery succeeded") + capabilities: list[DataFrameSourceCapability] = Field( + default_factory=list, + description="Capability metadata entries", + ) + total_count: int = Field(description="Number of returned capability entries") + error: str | None = Field(default=None, description="Error message if failed") + + +class DataFusionQueryRequest(BaseModel): + """Request to execute SQL via DataFusion against Parquet/Arrow sources.""" + + sql: str = Field(description="SQL query to execute via DataFusion", min_length=1) + sources: list[DataFusionSourceConfig] = Field( + min_length=1, + description="One or more source registrations for the DataFusion context", + ) + limit: int = Field( + default=1000, + ge=1, + le=10000, + description="Maximum number of rows to return", + ) + ingest_result: bool = Field( + default=False, + description="Register query results as a virtual dataset", + ) + result_dataset_name: str | None = Field( + default=None, + min_length=1, + max_length=100, + description="Optional name override for ingested query results", + ) + ttl_minutes: int = Field( + default=60, + ge=0, + le=1440, + description="TTL for ingested result dataset when ingest_result is enabled", + ) + allow_cross_session: bool = Field( + default=False, + description="Allow same-user cross-session access to ingested result dataset", + ) + + +class DataFusionQueryResponse(BaseModel): + """Response from executing a DataFusion query.""" + + success: bool = Field(description="Whether query execution succeeded") + rows: list[dict[str, str | int | float | bool | None]] | None = Field( + default=None, + description="Result rows from DataFusion query", + ) + columns: list[dict[str, str]] | None = Field( + default=None, + description="Column metadata for query results", + ) + row_count: int | None = Field(default=None, description="Number of rows returned") + dataset: VirtualDatasetInfo | None = Field( + default=None, + description="Virtual dataset info when ingest_result is enabled", + ) + dataset_id: str | None = Field( + default=None, + description="Raw virtual dataset UUID when ingest_result is enabled", + ) + virtual_dataset_id: str | None = Field( + default=None, + description="Prefixed virtual dataset ID when ingest_result is enabled", + ) + source_capabilities: list[DataFrameSourceCapability] = Field( + default_factory=list, + description="Capabilities for source adapters used during execution", + ) + warning: str | None = Field( + default=None, + description="Non-fatal warning message", + ) + error: str | None = Field(default=None, description="Error message if failed") + error_code: str | None = Field(default=None, description="Error code if failed") + + class DataFrameAnalysisResult(BaseModel): """ Analysis results for a DataFrame. 
@@ -282,3 +564,6 @@ class AnalyzeDataFrameResponse(BaseModel): default_factory=list, description="Chart recommendations" ) error: str | None = Field(default=None, description="Error message if failed") + + +PrometheusQueryResponse.model_rebuild() diff --git a/superset/mcp_service/dataframe/tool/__init__.py b/superset/mcp_service/dataframe/tool/__init__.py index 7ae5c38937a5..14c8c2587d59 100644 --- a/superset/mcp_service/dataframe/tool/__init__.py +++ b/superset/mcp_service/dataframe/tool/__init__.py @@ -22,9 +22,14 @@ """ from superset.mcp_service.dataframe.tool.ingest_dataframe import ingest_dataframe +from superset.mcp_service.dataframe.tool.list_source_capabilities import ( + list_source_capabilities, +) from superset.mcp_service.dataframe.tool.list_virtual_datasets import ( list_virtual_datasets, ) +from superset.mcp_service.dataframe.tool.query_datafusion import query_datafusion +from superset.mcp_service.dataframe.tool.query_prometheus import query_prometheus from superset.mcp_service.dataframe.tool.query_virtual_dataset import ( query_virtual_dataset, ) @@ -34,7 +39,10 @@ __all__ = [ "ingest_dataframe", + "list_source_capabilities", "list_virtual_datasets", + "query_datafusion", + "query_prometheus", "query_virtual_dataset", "remove_virtual_dataset", ] diff --git a/superset/mcp_service/dataframe/tool/common.py b/superset/mcp_service/dataframe/tool/common.py new file mode 100644 index 000000000000..62a94f7cf423 --- /dev/null +++ b/superset/mcp_service/dataframe/tool/common.py @@ -0,0 +1,127 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
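+
+# Shared helpers for the DataFrame MCP tools: read-only SQL validation and
+# conversion of Arrow tables into JSON-safe rows and column metadata. For
+# example, validate_read_only_sql("SELECT 1") returns None, while a
+# non-SELECT statement such as "DROP TABLE t" returns an error message.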
+ +from __future__ import annotations + +import math +import re +from datetime import date, datetime, time +from decimal import Decimal +from typing import Any + +import pyarrow as pa + +from superset.mcp_service.dataframe.registry import VirtualDataset +from superset.mcp_service.dataframe.schemas import VirtualDatasetInfo + +UNSAFE_SQL_KEYWORDS = re.compile( + r"\b(" + r"ALTER|ATTACH|CALL|COPY|CREATE|DELETE|DETACH|DROP|EXPORT|IMPORT|INSTALL|" + r"INSERT|LOAD|PRAGMA|SET|UPDATE|VACUUM" + r")\b", + re.IGNORECASE, +) + + +def normalize_sql(sql: str) -> str: + """Normalize SQL for safe wrapping and execution.""" + normalized_sql = sql.strip() + if normalized_sql.endswith(";"): + normalized_sql = normalized_sql[:-1].rstrip() + return normalized_sql + + +def validate_read_only_sql(sql: str) -> str | None: + """Validate SQL text to allow only single-statement read-only queries.""" + normalized_sql = normalize_sql(sql) + if not normalized_sql: + return "SQL query cannot be empty" + if ";" in normalized_sql: + return "Only single-statement SQL queries are allowed" + + normalized_upper = normalized_sql.upper() + if not ( + normalized_upper.startswith("SELECT") + or normalized_upper.startswith("WITH") + ): + return "Only SELECT and WITH queries are allowed" + if UNSAFE_SQL_KEYWORDS.search(normalized_sql): + return "Query contains non-read-only SQL keywords" + + return None + + +def convert_value(value: Any) -> str | int | float | bool | None: + """Convert a scalar to a JSON-safe primitive representation.""" + if value is None: + return None + if isinstance(value, (str, int, float, bool)): + return value + if isinstance(value, (datetime, date, time)): + return value.isoformat() + if isinstance(value, Decimal): + numeric = float(value) + if math.isfinite(numeric): + return numeric + return str(value) + + try: + numeric = float(value) + if math.isfinite(numeric): + return numeric + except (TypeError, ValueError): + pass + + return str(value) + + +def table_to_rows( + table: pa.Table, +) -> list[dict[str, str | int | float | bool | None]]: + """Convert an Arrow table into JSON-safe rows.""" + raw_rows = table.to_pylist() + return [ + { + column_name: convert_value(column_value) + for column_name, column_value in row.items() + } + for row in raw_rows + ] + + +def table_to_columns(table: pa.Table) -> list[dict[str, str]]: + """Build typed column metadata from an Arrow table schema.""" + return [{"name": field.name, "type": str(field.type)} for field in table.schema] + + +def build_virtual_dataset_info( + dataset: VirtualDataset, + description: str | None = None, +) -> VirtualDatasetInfo: + """Build a response model for a registered virtual dataset.""" + return VirtualDatasetInfo( + id=dataset.id, + name=dataset.name, + row_count=dataset.row_count, + column_count=len(dataset.column_names), + size_bytes=dataset.size_bytes, + size_mb=round(dataset.size_bytes / 1024 / 1024, 2), + created_at=dataset.created_at, + expires_at=dataset.expires_at, + columns=dataset.get_column_info(), + description=description, + ) diff --git a/superset/mcp_service/dataframe/tool/context.py b/superset/mcp_service/dataframe/tool/context.py new file mode 100644 index 000000000000..d3197880d4c4 --- /dev/null +++ b/superset/mcp_service/dataframe/tool/context.py @@ -0,0 +1,41 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +from fastmcp import Context + +from superset.utils.core import get_user_id + + +def resolve_session_and_user(ctx: Context) -> tuple[str | None, int | None]: + """ + Resolve session and user identity for DataFrame virtual dataset access. + + Falls back to a deterministic per-user session identifier when FastMCP + session context is unavailable. + """ + session_id = getattr(ctx, "session_id", None) + try: + user_id = get_user_id() + except Exception: + user_id = None + + if not session_id and user_id is not None: + session_id = f"user_{user_id}" + + return session_id, user_id diff --git a/superset/mcp_service/dataframe/tool/ingest_dataframe.py b/superset/mcp_service/dataframe/tool/ingest_dataframe.py index 969cdefd30a9..0bb59f7783bb 100644 --- a/superset/mcp_service/dataframe/tool/ingest_dataframe.py +++ b/superset/mcp_service/dataframe/tool/ingest_dataframe.py @@ -38,8 +38,8 @@ IngestDataFrameResponse, VirtualDatasetInfo, ) +from superset.mcp_service.dataframe.tool.context import resolve_session_and_user from superset.mcp_service.utils.schema_utils import parse_request -from superset.utils.core import get_user_id logger = logging.getLogger(__name__) @@ -54,10 +54,10 @@ async def ingest_dataframe( This tool allows AI agents to upload DataFrame data directly without requiring database storage. The data is registered as a virtual dataset - that can be used with generate_chart and other visualization tools. + that can be queried and managed via DataFrame MCP tools. - IMPORTANT: Use 'virtual:{dataset_id}' format when referencing the - dataset in generate_chart or other tools. + IMPORTANT: Store the returned dataset_id for follow-up calls to + query_virtual_dataset, list_virtual_datasets, and remove_virtual_dataset. Example workflow (Python): ```python @@ -84,15 +84,10 @@ async def ingest_dataframe( "ttl_minutes": 60 }) - # 4. Use with generate_chart - chart = await generate_chart({ - "dataset_id": f"virtual:{result.dataset_id}", - "config": { - "chart_type": "xy", - "x": {"name": "date"}, - "y": [{"name": "sales", "aggregate": "SUM"}], - "kind": "line" - } + # 4. Query with virtual dataset tools + rows = await query_virtual_dataset({ + "dataset_id": result.dataset_id, + "sql": "SELECT region, SUM(sales) AS total FROM data GROUP BY region" }) ``` @@ -146,26 +141,17 @@ async def ingest_dataframe( % (table.num_rows, table.num_columns) ) - # Get user and session IDs - try: - user_id = get_user_id() - except Exception: - user_id = None - - session_id = getattr(ctx, "session_id", None) + # Resolve identity context for access control. 
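+        # resolve_session_and_user derives a per-user fallback session ID
+        # when no FastMCP session context is available.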
+ session_id, user_id = resolve_session_and_user(ctx) if not session_id: - if user_id is not None: - # Derive a per-user fallback session ID to avoid cross-user collisions - session_id = f"user_{user_id}" - else: - logger.error( - "Missing both session_id and user_id; refusing to ingest DataFrame" - ) - return IngestDataFrameResponse( - success=False, - error="Missing session and user context; cannot safely ingest DataFrame", - error_code="MISSING_SESSION_CONTEXT", - ) + logger.error( + "Missing both session_id and user_id; refusing to ingest DataFrame" + ) + return IngestDataFrameResponse( + success=False, + error="Missing session and user context; cannot safely ingest DataFrame", + error_code="MISSING_SESSION_CONTEXT", + ) # Calculate TTL ttl = ( timedelta(minutes=request.ttl_minutes) @@ -231,8 +217,10 @@ async def ingest_dataframe( dataset_id=dataset_id, virtual_dataset_id=virtual_dataset_id, usage_hint=( - f"Use '{virtual_dataset_id}' as the dataset_id in generate_chart " - "or other visualization tools." + f"Use '{dataset_id}' with query_virtual_dataset and " + "remove_virtual_dataset. Keep " + f"'{virtual_dataset_id}' for integrations that require " + "prefixed virtual dataset IDs." ), ) diff --git a/superset/mcp_service/dataframe/tool/list_source_capabilities.py b/superset/mcp_service/dataframe/tool/list_source_capabilities.py new file mode 100644 index 000000000000..fa7490c60fc0 --- /dev/null +++ b/superset/mcp_service/dataframe/tool/list_source_capabilities.py @@ -0,0 +1,107 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +List DataFrame Source Capabilities MCP Tool + +Returns discoverability metadata for supported dataframe source adapters. +""" + +from __future__ import annotations + +import logging + +from fastmcp import Context +from superset_core.mcp import tool + +from superset.mcp_service.dataframe.schemas import ( + DataFrameSourceCapability, + ListSourceCapabilitiesRequest, + ListSourceCapabilitiesResponse, +) +from superset.mcp_service.dataframe.tool.source_adapters import ( + PROMETHEUS_SOURCE_CAPABILITY, + list_datafusion_source_capabilities, +) +from superset.mcp_service.utils.schema_utils import parse_request + +logger = logging.getLogger(__name__) + +PROMETHEUS_SOURCE_TYPE = "prometheus_http" + + +@tool(tags=["dataframe", "metadata"]) +@parse_request(ListSourceCapabilitiesRequest) +async def list_source_capabilities( + request: ListSourceCapabilitiesRequest, ctx: Context +) -> ListSourceCapabilitiesResponse: + """ + List source capability metadata for dataframe tools. 
+ + This helps agents decide which connector path to use based on support for: + - streaming + - projection/predicate pushdown + - SQL pushdown + - virtual dataset ingestion + """ + await ctx.info("Listing dataframe source capabilities") + + try: + requested_types = ( + {value.strip() for value in request.source_types if value.strip()} + if request.source_types + else None + ) + + datafusion_filter = None + if requested_types is not None: + datafusion_filter = [ + source_type + for source_type in requested_types + if source_type != PROMETHEUS_SOURCE_TYPE + ] + + capabilities = [ + DataFrameSourceCapability.model_validate(capability.__dict__) + for capability in list_datafusion_source_capabilities(datafusion_filter) + ] + + include_prometheus = request.include_prometheus and ( + requested_types is None or PROMETHEUS_SOURCE_TYPE in requested_types + ) + if include_prometheus: + capabilities.append( + DataFrameSourceCapability.model_validate( + PROMETHEUS_SOURCE_CAPABILITY.__dict__ + ) + ) + + await ctx.info("Listed %d source capabilities" % len(capabilities)) + return ListSourceCapabilitiesResponse( + success=True, + capabilities=capabilities, + total_count=len(capabilities), + ) + except Exception as ex: + logger.exception("Failed to list source capabilities: %s", ex) + await ctx.error("Failed to list source capabilities: %s" % str(ex)) + return ListSourceCapabilitiesResponse( + success=False, + capabilities=[], + total_count=0, + error=str(ex), + ) diff --git a/superset/mcp_service/dataframe/tool/list_virtual_datasets.py b/superset/mcp_service/dataframe/tool/list_virtual_datasets.py index da0c126fe7d4..c4d53372e4fb 100644 --- a/superset/mcp_service/dataframe/tool/list_virtual_datasets.py +++ b/superset/mcp_service/dataframe/tool/list_virtual_datasets.py @@ -34,7 +34,7 @@ ListVirtualDatasetsResponse, VirtualDatasetInfo, ) -from superset.utils.core import get_user_id +from superset.mcp_service.dataframe.tool.context import resolve_session_and_user logger = logging.getLogger(__name__) @@ -58,12 +58,18 @@ async def list_virtual_datasets(ctx: Context) -> ListVirtualDatasetsResponse: await ctx.info("Listing virtual datasets") try: - # Get session ID from context - session_id = getattr(ctx, "session_id", None) or "default_session" - try: - user_id = get_user_id() - except Exception: - user_id = None + session_id, user_id = resolve_session_and_user(ctx) + if not session_id and user_id is None: + message = ( + "Missing session and user context; cannot safely list " + "virtual datasets" + ) + await ctx.error(message) + return ListVirtualDatasetsResponse( + datasets=[], + total_count=0, + total_size_mb=0.0, + ) # Get registry and list datasets registry = get_registry() diff --git a/superset/mcp_service/dataframe/tool/query_datafusion.py b/superset/mcp_service/dataframe/tool/query_datafusion.py new file mode 100644 index 000000000000..434ca01f7e5f --- /dev/null +++ b/superset/mcp_service/dataframe/tool/query_datafusion.py @@ -0,0 +1,185 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +DataFusion Query MCP Tool + +Execute SQL against DataFusion with Parquet, Arrow IPC, and virtual dataset +sources. +""" + +from __future__ import annotations + +import logging +from datetime import timedelta + +import pyarrow as pa +from fastmcp import Context +from superset_core.mcp import tool + +from superset.mcp_service.dataframe.identifiers import to_virtual_dataset_id +from superset.mcp_service.dataframe.registry import get_registry +from superset.mcp_service.dataframe.schemas import ( + DataFrameSourceCapability, + DataFusionQueryRequest, + DataFusionQueryResponse, +) +from superset.mcp_service.dataframe.tool.common import ( + build_virtual_dataset_info, + normalize_sql, + table_to_columns, + table_to_rows, + validate_read_only_sql, +) +from superset.mcp_service.dataframe.tool.context import resolve_session_and_user +from superset.mcp_service.dataframe.tool.source_adapters import ( + get_datafusion_source_adapter, +) +from superset.mcp_service.utils.schema_utils import parse_request + +logger = logging.getLogger(__name__) + + +@tool(tags=["dataframe", "source"]) +@parse_request(DataFusionQueryRequest) +async def query_datafusion( + request: DataFusionQueryRequest, + ctx: Context, +) -> DataFusionQueryResponse: + """Execute SQL using DataFusion across registered Parquet/Arrow sources.""" + await ctx.info( + "Executing DataFusion query: sources=%d, ingest=%s" + % (len(request.sources), request.ingest_result) + ) + + try: + try: + from datafusion import SessionContext + except ImportError: + return DataFusionQueryResponse( + success=False, + error=( + "DataFusion is required for this tool. Install with " + "`pip install datafusion`." 
+ ), + error_code="DATAFUSION_NOT_INSTALLED", + ) + + validation_error = validate_read_only_sql(request.sql) + if validation_error: + return DataFusionQueryResponse( + success=False, + error=validation_error, + error_code="INVALID_SQL", + ) + + session_id, user_id = resolve_session_and_user(ctx) + registry = get_registry() + source_capabilities: list[DataFrameSourceCapability] = [] + + session_ctx = SessionContext() + for source in request.sources: + adapter = get_datafusion_source_adapter(source.source_type) + if adapter is None: + return DataFusionQueryResponse( + success=False, + error=f"Unsupported source_type '{source.source_type}'", + error_code="UNSUPPORTED_SOURCE_TYPE", + source_capabilities=source_capabilities, + ) + try: + adapter.register_source( + session_ctx=session_ctx, + source=source, + registry=registry, + session_id=session_id, + user_id=user_id, + ) + except ValueError as ex: + error_code = "INVALID_SOURCE_CONFIG" + if source.source_type == "virtual_dataset": + error_code = "VIRTUAL_DATASET_NOT_FOUND" + return DataFusionQueryResponse( + success=False, + error=str(ex), + error_code=error_code, + source_capabilities=source_capabilities, + ) + source_capabilities.append( + DataFrameSourceCapability.model_validate(adapter.capability.__dict__) + ) + + normalized_sql = normalize_sql(request.sql) + sql = f"SELECT * FROM ({normalized_sql}) AS subq LIMIT {request.limit}" + await ctx.debug("Executing DataFusion SQL: %s" % sql[:200]) + + query_df = session_ctx.sql(sql) + result_batches = query_df.collect() + result_table = ( + pa.Table.from_batches(result_batches) + if result_batches + else pa.Table.from_pylist([]) + ) + + rows = table_to_rows(result_table) + columns = table_to_columns(result_table) if result_table.num_columns > 0 else [] + + response = DataFusionQueryResponse( + success=True, + rows=rows, + columns=columns, + row_count=len(rows), + source_capabilities=source_capabilities, + ) + + if request.ingest_result: + if not session_id: + response.warning = ( + "DataFusion query succeeded, but result ingestion was skipped due " + "to missing session/user context" + ) + return response + + result_name = request.result_dataset_name or "datafusion_query_result" + ttl = ( + timedelta(minutes=request.ttl_minutes) + if request.ttl_minutes > 0 + else timedelta(seconds=0) + ) + dataset_id = registry.register( + name=result_name, + table=result_table, + session_id=session_id, + user_id=user_id, + ttl=ttl, + allow_cross_session=request.allow_cross_session, + ) + dataset = registry.get(dataset_id, session_id=session_id, user_id=user_id) + if dataset is not None: + response.dataset = build_virtual_dataset_info(dataset) + response.dataset_id = dataset_id + response.virtual_dataset_id = to_virtual_dataset_id(dataset_id) + + return response + except Exception as ex: + logger.exception("DataFusion query failed: %s", ex) + await ctx.error("DataFusion query failed: %s" % str(ex)) + return DataFusionQueryResponse( + success=False, + error=f"Unexpected error: {ex}", + error_code="UNEXPECTED_ERROR", + ) diff --git a/superset/mcp_service/dataframe/tool/query_prometheus.py b/superset/mcp_service/dataframe/tool/query_prometheus.py new file mode 100644 index 000000000000..41e877482b76 --- /dev/null +++ b/superset/mcp_service/dataframe/tool/query_prometheus.py @@ -0,0 +1,281 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +Prometheus Query MCP Tool + +Query Prometheus via HTTP API and optionally ingest results as a virtual dataset. +""" + +from __future__ import annotations + +import logging +import math +import ssl +from datetime import datetime, timedelta, timezone +from typing import Any +from urllib.error import HTTPError, URLError +from urllib.parse import urlencode +from urllib.request import Request, urlopen + +import pyarrow as pa +from fastmcp import Context +from superset_core.mcp import tool + +from superset.mcp_service.dataframe.identifiers import to_virtual_dataset_id +from superset.mcp_service.dataframe.registry import get_registry +from superset.mcp_service.dataframe.schemas import ( + DataFrameSourceCapability, + PrometheusQueryRequest, + PrometheusQueryResponse, +) +from superset.mcp_service.dataframe.tool.common import ( + build_virtual_dataset_info, + convert_value, + table_to_columns, +) +from superset.mcp_service.dataframe.tool.context import resolve_session_and_user +from superset.mcp_service.dataframe.tool.source_adapters import ( + PROMETHEUS_SOURCE_CAPABILITY, +) +from superset.mcp_service.utils.schema_utils import parse_request +from superset.utils import json + +logger = logging.getLogger(__name__) + + +def _to_timestamp(value: datetime) -> float: + """Convert datetime to a UTC unix timestamp.""" + if value.tzinfo is None: + value = value.replace(tzinfo=timezone.utc) + return value.astimezone(timezone.utc).timestamp() + + +def _format_timestamp(value: Any) -> str: + """Format Prometheus timestamp values as ISO-8601 UTC strings.""" + try: + ts = float(value) + return datetime.fromtimestamp(ts, tz=timezone.utc).isoformat() + except (TypeError, ValueError): + return str(value) + + +def _parse_prometheus_scalar(value: Any) -> str | int | float | bool | None: + """Parse Prometheus scalar text values into JSON-safe primitives.""" + parsed = convert_value(value) + if isinstance(parsed, float): + if not math.isfinite(parsed): + return str(value) + return parsed + + +def _flatten_prometheus_result( + result_type: str, + result: Any, +) -> list[dict[str, str | int | float | bool | None]]: + """Flatten Prometheus result payload into row records.""" + rows: list[dict[str, str | int | float | bool | None]] = [] + if result_type == "matrix": + if not isinstance(result, list): + return rows + for series in result: + if not isinstance(series, dict): + continue + metric = series.get("metric", {}) + metric_labels = {str(k): str(v) for k, v in metric.items()} + for point in series.get("values", []): + if not isinstance(point, (list, tuple)) or len(point) != 2: + continue + rows.append( + { + **metric_labels, + "timestamp": _format_timestamp(point[0]), + "value": _parse_prometheus_scalar(point[1]), + } + ) + return rows + + if result_type == "vector": + if not isinstance(result, list): + return rows + for series in result: + if not isinstance(series, dict): + continue + metric = series.get("metric", {}) + metric_labels = 
{str(k): str(v) for k, v in metric.items()} + point = series.get("value", []) + if isinstance(point, (list, tuple)) and len(point) == 2: + rows.append( + { + **metric_labels, + "timestamp": _format_timestamp(point[0]), + "value": _parse_prometheus_scalar(point[1]), + } + ) + return rows + + if result_type in {"scalar", "string"}: + if isinstance(result, (list, tuple)) and len(result) == 2: + rows.append( + { + "timestamp": _format_timestamp(result[0]), + "value": _parse_prometheus_scalar(result[1]), + } + ) + return rows + + return rows + + +@tool(tags=["dataframe", "source"]) +@parse_request(PrometheusQueryRequest) +async def query_prometheus( + request: PrometheusQueryRequest, + ctx: Context, +) -> PrometheusQueryResponse: + """ + Execute a PromQL query against Prometheus. + + For range queries, results are flattened into one row per sample with label + columns, a UTC timestamp column, and a value column. + """ + await ctx.info( + "Querying Prometheus: query_type=%s, ingest=%s" + % (request.query_type, request.ingest_as_virtual_dataset) + ) + + try: + base_url = request.base_url.rstrip("/") + if request.query_type == "range": + endpoint = f"{base_url}/api/v1/query_range" + end_time = request.end_time or datetime.now(timezone.utc) + start_time = request.start_time or (end_time - timedelta(hours=1)) + params = { + "query": request.promql, + "start": _to_timestamp(start_time), + "end": _to_timestamp(end_time), + "step": request.step_seconds, + } + else: + endpoint = f"{base_url}/api/v1/query" + time_value = request.end_time or datetime.now(timezone.utc) + params = { + "query": request.promql, + "time": _to_timestamp(time_value), + } + + params["timeout"] = f"{request.timeout_seconds}s" + url = f"{endpoint}?{urlencode(params)}" + + await ctx.debug("Prometheus request URL prepared") + + ssl_context = ssl.create_default_context() + if not request.verify_ssl: + ssl_context.check_hostname = False + ssl_context.verify_mode = ssl.CERT_NONE + + req = Request(url, headers={"Accept": "application/json"}) + try: + with urlopen( + req, + timeout=request.timeout_seconds, + context=ssl_context, + ) as response: + payload = json.loads(response.read().decode("utf-8")) + except HTTPError as ex: + logger.error("Prometheus HTTP error: %s", ex) + return PrometheusQueryResponse( + success=False, + error=f"Prometheus HTTP error: {ex.code}", + error_code="PROMETHEUS_HTTP_ERROR", + ) + except URLError as ex: + logger.error("Prometheus connection error: %s", ex) + return PrometheusQueryResponse( + success=False, + error=f"Prometheus connection error: {ex}", + error_code="PROMETHEUS_CONNECTION_ERROR", + ) + + if payload.get("status") != "success": + error_message = payload.get("error") or "Prometheus query failed" + return PrometheusQueryResponse( + success=False, + error=str(error_message), + error_code="PROMETHEUS_QUERY_FAILED", + ) + + data = payload.get("data", {}) + result_type = str(data.get("resultType") or "") + result = data.get("result", []) + + rows = _flatten_prometheus_result(result_type, result) + table = pa.Table.from_pylist(rows) if rows else pa.Table.from_pylist([]) + columns = table_to_columns(table) if rows else [] + + response = PrometheusQueryResponse( + success=True, + result_type=result_type, + rows=rows, + columns=columns, + row_count=len(rows), + source_capabilities=[ + DataFrameSourceCapability.model_validate( + PROMETHEUS_SOURCE_CAPABILITY.__dict__ + ) + ], + ) + + if request.ingest_as_virtual_dataset: + session_id, user_id = resolve_session_and_user(ctx) + if not session_id: + 
response.warning = ( + "Prometheus query succeeded, but ingestion was skipped due to " + "missing session/user context" + ) + return response + + dataset_name = request.dataset_name or "prometheus_query_result" + ttl = ( + timedelta(minutes=request.ttl_minutes) + if request.ttl_minutes > 0 + else timedelta(seconds=0) + ) + registry = get_registry() + dataset_id = registry.register( + name=dataset_name, + table=table, + session_id=session_id, + user_id=user_id, + ttl=ttl, + allow_cross_session=request.allow_cross_session, + ) + dataset = registry.get(dataset_id, session_id=session_id, user_id=user_id) + if dataset is not None: + response.dataset = build_virtual_dataset_info(dataset) + response.dataset_id = dataset_id + response.virtual_dataset_id = to_virtual_dataset_id(dataset_id) + + return response + except Exception as ex: + logger.exception("Prometheus query tool failed: %s", ex) + await ctx.error("Prometheus query failed: %s" % str(ex)) + return PrometheusQueryResponse( + success=False, + error=f"Unexpected error: {ex}", + error_code="UNEXPECTED_ERROR", + ) diff --git a/superset/mcp_service/dataframe/tool/query_virtual_dataset.py b/superset/mcp_service/dataframe/tool/query_virtual_dataset.py index ebdcab98a395..902b1917dbdb 100644 --- a/superset/mcp_service/dataframe/tool/query_virtual_dataset.py +++ b/superset/mcp_service/dataframe/tool/query_virtual_dataset.py @@ -24,38 +24,27 @@ from __future__ import annotations import logging -from typing import Any from fastmcp import Context from superset_core.mcp import tool +from superset.mcp_service.dataframe.identifiers import normalize_virtual_dataset_id from superset.mcp_service.dataframe.registry import get_registry from superset.mcp_service.dataframe.schemas import ( QueryVirtualDatasetRequest, QueryVirtualDatasetResponse, ) +from superset.mcp_service.dataframe.tool.common import ( + normalize_sql, + table_to_columns, + table_to_rows, + validate_read_only_sql, +) +from superset.mcp_service.dataframe.tool.context import resolve_session_and_user from superset.mcp_service.utils.schema_utils import parse_request -from superset.utils.core import get_user_id logger = logging.getLogger(__name__) - -def _convert_value(value: Any) -> str | int | float | bool | None: - """Convert a value to a JSON-serializable type.""" - if value is None: - return None - if isinstance(value, (str, int, float, bool)): - return value - # Handle numpy types and other types - try: - # Try to convert to float first (handles numpy numeric types) - return float(value) - except (TypeError, ValueError): - pass - # Fall back to string conversion - return str(value) - - @tool(tags=["dataframe"]) @parse_request(QueryVirtualDatasetRequest) async def query_virtual_dataset( @@ -73,7 +62,7 @@ async def query_virtual_dataset( - SELECT date, sales FROM data WHERE sales > 100 ORDER BY date Args: - dataset_id: The virtual dataset ID (without 'virtual:' prefix) + dataset_id: The virtual dataset ID (raw UUID or virtual:{uuid}) sql: SQL query to execute limit: Maximum rows to return (default 1000) @@ -90,32 +79,25 @@ async def query_virtual_dataset( registry = get_registry() # Determine session and user context - session_id = getattr(ctx, "session_id", None) - try: - user_id = get_user_id() - except Exception: - user_id = None + session_id, user_id = resolve_session_and_user(ctx) - # Avoid falling back to a shared default session identifier if not session_id: - if user_id is not None: - # Derive a per-user session identifier to prevent collisions - session_id = 
f"user_session_{user_id}" - else: - logger.warning( - "Missing session_id and user_id; refusing to query virtual dataset " - "for dataset_id=%s", - request.dataset_id, - ) - return QueryVirtualDatasetResponse( - success=False, - error=( - "Cannot query virtual dataset: missing session and user " - "context. Please retry after re-authenticating." - ), - ) + logger.warning( + "Missing session_id and user_id; refusing to query virtual dataset " + "for dataset_id=%s", + request.dataset_id, + ) + return QueryVirtualDatasetResponse( + success=False, + error=( + "Cannot query virtual dataset: missing session and user " + "context. Please retry after re-authenticating." + ), + ) dataset = registry.get( - request.dataset_id, session_id=session_id, user_id=user_id + normalize_virtual_dataset_id(request.dataset_id), + session_id=session_id, + user_id=user_id, ) if dataset is None: @@ -140,25 +122,18 @@ async def query_virtual_dataset( ), ) + validation_error = validate_read_only_sql(request.sql) + if validation_error: + return QueryVirtualDatasetResponse(success=False, error=validation_error) + + # Always enforce an outer LIMIT to cap the result size. + normalized_sql = normalize_sql(request.sql) + sql = f"SELECT * FROM ({normalized_sql}) AS subq LIMIT {request.limit}" + # Create DuckDB connection and register the table conn = duckdb.connect() conn.register("data", dataset.table) - # Execute query with limit - sql = request.sql.strip() - - # Add LIMIT if not present (safety measure) - sql_upper = sql.upper() - if "LIMIT" not in sql_upper: - sql = f"{sql} LIMIT {request.limit}" - - # Remove a trailing semicolon so the query can be safely wrapped - if sql.endswith(";"): - sql = sql[:-1].rstrip() - - # Always enforce an outer LIMIT to cap the result size - sql = f"SELECT * FROM ({sql}) AS subq" - await ctx.debug("Executing SQL: %s" % sql[:200]) try: @@ -174,12 +149,8 @@ async def query_virtual_dataset( conn.close() # Convert to response format - df = result_table.to_pandas() - rows = [] - for _, row in df.iterrows(): - rows.append({col: _convert_value(row[col]) for col in df.columns}) - - columns = [{"name": col, "type": str(df[col].dtype)} for col in df.columns] + rows = table_to_rows(result_table) + columns = table_to_columns(result_table) await ctx.info( "Query completed: rows=%d, columns=%d" % (len(rows), len(columns)) diff --git a/superset/mcp_service/dataframe/tool/remove_virtual_dataset.py b/superset/mcp_service/dataframe/tool/remove_virtual_dataset.py index c25bfb6c8713..613993429c7a 100644 --- a/superset/mcp_service/dataframe/tool/remove_virtual_dataset.py +++ b/superset/mcp_service/dataframe/tool/remove_virtual_dataset.py @@ -28,13 +28,14 @@ from fastmcp import Context from superset_core.mcp import tool +from superset.mcp_service.dataframe.identifiers import normalize_virtual_dataset_id from superset.mcp_service.dataframe.registry import get_registry from superset.mcp_service.dataframe.schemas import ( RemoveVirtualDatasetRequest, RemoveVirtualDatasetResponse, ) +from superset.mcp_service.dataframe.tool.context import resolve_session_and_user from superset.mcp_service.utils.schema_utils import parse_request -from superset.utils.core import get_user_id logger = logging.getLogger(__name__) @@ -52,7 +53,7 @@ async def remove_virtual_dataset( immediately reclaim resources. Args: - dataset_id: The virtual dataset ID to remove + dataset_id: The virtual dataset ID to remove (raw UUID or virtual:{uuid}) Returns: Success status and message. 
@@ -61,12 +62,8 @@ async def remove_virtual_dataset( try: registry = get_registry() - try: - user_id = get_user_id() - except Exception: - user_id = None - - session_id = getattr(ctx, "session_id", None) + normalized_dataset_id = normalize_virtual_dataset_id(request.dataset_id) + session_id, user_id = resolve_session_and_user(ctx) if session_id is None: if user_id is None: message = ( @@ -78,10 +75,11 @@ async def remove_virtual_dataset( success=False, message=message, ) - session_id = f"user_{user_id}" # Check if dataset exists dataset = registry.get( - request.dataset_id, session_id=session_id, user_id=user_id + normalized_dataset_id, + session_id=session_id, + user_id=user_id, ) if dataset is None: return RemoveVirtualDatasetResponse( @@ -94,7 +92,9 @@ async def remove_virtual_dataset( # Remove the dataset removed = registry.remove( - request.dataset_id, session_id=session_id, user_id=user_id + normalized_dataset_id, + session_id=session_id, + user_id=user_id, ) if removed: diff --git a/superset/mcp_service/dataframe/tool/source_adapters.py b/superset/mcp_service/dataframe/tool/source_adapters.py new file mode 100644 index 000000000000..b4a7bbec871c --- /dev/null +++ b/superset/mcp_service/dataframe/tool/source_adapters.py @@ -0,0 +1,220 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +Source adapter abstractions for dataframe tools. + +This module formalizes source registration behavior and capability metadata +for DataFusion and related MCP dataframe tools. 
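+
+Each adapter implements the ``DataFrameSourceAdapter`` protocol (a
+``source_type``, a ``capability`` description, and a ``register_source``
+method); new source types are wired in by adding an adapter instance to
+``DATAFUSION_SOURCE_ADAPTERS``.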
+""" + +from __future__ import annotations + +import base64 +from dataclasses import dataclass +from typing import Any, Protocol + +import pyarrow as pa + +from superset.mcp_service.dataframe.identifiers import normalize_virtual_dataset_id +from superset.mcp_service.dataframe.registry import VirtualDatasetRegistry +from superset.mcp_service.dataframe.schemas import DataFusionSourceConfig + + +def _register_arrow_table(session_ctx: Any, table_name: str, table: pa.Table) -> None: + """Register an Arrow table in DataFusion across supported APIs.""" + batches = table.to_batches() + + if hasattr(session_ctx, "register_record_batches"): + try: + session_ctx.register_record_batches(table_name, [batches]) + return + except TypeError: + session_ctx.register_record_batches(table_name, batches) + return + + if hasattr(session_ctx, "from_arrow_table"): + df = session_ctx.from_arrow_table(table) + if hasattr(df, "create_temp_view"): + df.create_temp_view(table_name) + return + + raise RuntimeError( + "DataFusion runtime does not support Arrow table registration " + "with available APIs" + ) + + +@dataclass(frozen=True) +class DataFrameSourceCapability: + """Capability metadata for a dataframe source adapter.""" + + source_type: str + adapter_name: str + supports_streaming: bool + supports_projection_pushdown: bool + supports_predicate_pushdown: bool + supports_sql_pushdown: bool + supports_virtual_dataset_ingestion: bool + + +class DataFrameSourceAdapter(Protocol): + """Contract for dataframe source registration adapters.""" + + source_type: str + capability: DataFrameSourceCapability + + def register_source( + self, + session_ctx: Any, + source: DataFusionSourceConfig, + registry: VirtualDatasetRegistry, + session_id: str | None, + user_id: int | None, + ) -> None: + """Register a source in an active DataFusion SessionContext.""" + + +class ParquetSourceAdapter: + """Adapter for parquet data sources.""" + + source_type = "parquet" + capability = DataFrameSourceCapability( + source_type=source_type, + adapter_name="ParquetSourceAdapter", + supports_streaming=True, + supports_projection_pushdown=True, + supports_predicate_pushdown=True, + supports_sql_pushdown=True, + supports_virtual_dataset_ingestion=True, + ) + + def register_source( + self, + session_ctx: Any, + source: DataFusionSourceConfig, + registry: VirtualDatasetRegistry, # noqa: ARG002 + session_id: str | None, # noqa: ARG002 + user_id: int | None, # noqa: ARG002 + ) -> None: + if source.path is None: + raise ValueError("Missing parquet path in source configuration") + session_ctx.register_parquet(source.name, source.path) + + +class ArrowIpcSourceAdapter: + """Adapter for Arrow IPC payloads.""" + + source_type = "arrow_ipc" + capability = DataFrameSourceCapability( + source_type=source_type, + adapter_name="ArrowIpcSourceAdapter", + supports_streaming=False, + supports_projection_pushdown=False, + supports_predicate_pushdown=False, + supports_sql_pushdown=False, + supports_virtual_dataset_ingestion=True, + ) + + def register_source( + self, + session_ctx: Any, + source: DataFusionSourceConfig, + registry: VirtualDatasetRegistry, # noqa: ARG002 + session_id: str | None, # noqa: ARG002 + user_id: int | None, # noqa: ARG002 + ) -> None: + if source.data is None: + raise ValueError("Missing Arrow IPC payload in source configuration") + raw_data = base64.b64decode(source.data) + reader = pa.ipc.open_stream(pa.BufferReader(raw_data)) + table = reader.read_all() + _register_arrow_table(session_ctx, source.name, table) + + +class 
VirtualDatasetSourceAdapter: + """Adapter for registered MCP virtual datasets.""" + + source_type = "virtual_dataset" + capability = DataFrameSourceCapability( + source_type=source_type, + adapter_name="VirtualDatasetSourceAdapter", + supports_streaming=False, + supports_projection_pushdown=False, + supports_predicate_pushdown=False, + supports_sql_pushdown=False, + supports_virtual_dataset_ingestion=True, + ) + + def register_source( + self, + session_ctx: Any, + source: DataFusionSourceConfig, + registry: VirtualDatasetRegistry, + session_id: str | None, + user_id: int | None, + ) -> None: + if source.dataset_id is None: + raise ValueError("Missing dataset_id in source configuration") + + raw_dataset_id = normalize_virtual_dataset_id(source.dataset_id) + dataset = registry.get( + raw_dataset_id, + session_id=session_id, + user_id=user_id, + ) + if dataset is None: + raise ValueError( + f"Virtual dataset '{source.dataset_id}' not found, " + "expired, or access denied" + ) + + _register_arrow_table(session_ctx, source.name, dataset.table) + + +DATAFUSION_SOURCE_ADAPTERS: dict[str, DataFrameSourceAdapter] = { + "parquet": ParquetSourceAdapter(), + "arrow_ipc": ArrowIpcSourceAdapter(), + "virtual_dataset": VirtualDatasetSourceAdapter(), +} + +PROMETHEUS_SOURCE_CAPABILITY = DataFrameSourceCapability( + source_type="prometheus_http", + adapter_name="PrometheusHttpAdapter", + supports_streaming=False, + supports_projection_pushdown=False, + supports_predicate_pushdown=True, + supports_sql_pushdown=False, + supports_virtual_dataset_ingestion=True, +) + + +def get_datafusion_source_adapter(source_type: str) -> DataFrameSourceAdapter | None: + """Get the registered DataFusion source adapter for source_type.""" + return DATAFUSION_SOURCE_ADAPTERS.get(source_type) + + +def list_datafusion_source_capabilities( + source_types: list[str] | None = None, +) -> list[DataFrameSourceCapability]: + """List capability metadata for known DataFusion source adapters.""" + if source_types is None: + source_types = list(DATAFUSION_SOURCE_ADAPTERS.keys()) + return [ + DATAFUSION_SOURCE_ADAPTERS[source_type].capability + for source_type in source_types + if source_type in DATAFUSION_SOURCE_ADAPTERS + ] diff --git a/superset/mcp_service/explore/tool/generate_explore_link.py b/superset/mcp_service/explore/tool/generate_explore_link.py index d1f07daf9295..ce53a50d65fd 100644 --- a/superset/mcp_service/explore/tool/generate_explore_link.py +++ b/superset/mcp_service/explore/tool/generate_explore_link.py @@ -35,6 +35,7 @@ from superset.mcp_service.chart.schemas import ( GenerateExploreLinkRequest, ) +from superset.mcp_service.dataframe.identifiers import is_virtual_dataset_identifier from superset.mcp_service.utils.schema_utils import parse_request @@ -115,6 +116,20 @@ async def generate_explore_link( # Generate explore link using shared utilities explore_url = generate_url(dataset_id=request.dataset_id, form_data=form_data) + if is_virtual_dataset_identifier(request.dataset_id): + message = ( + "Explore links are unavailable for virtual datasets. " + "Use generate_chart with preview formats " + "['table', 'ascii', 'vega_lite']." 
+ ) + await ctx.warning(message) + return { + "url": "", + "form_data": form_data, + "form_data_key": None, + "error": message, + } + # Extract form_data_key from the explore URL using proper URL parsing form_data_key = None if explore_url: diff --git a/superset/mcp_service/run_proxy.sh b/superset/mcp_service/run_proxy.sh index 8a58d9db7e9d..5245454931c4 100755 --- a/superset/mcp_service/run_proxy.sh +++ b/superset/mcp_service/run_proxy.sh @@ -29,14 +29,16 @@ SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" # Get the project root (two levels up from mcp_service) PROJECT_ROOT="$(cd "$SCRIPT_DIR/../.." && pwd)" -# Use python from the virtual environment if it exists, otherwise use system python +# Use python from the virtual environment if it exists, otherwise use uv if [ -f "$PROJECT_ROOT/venv/bin/python" ]; then - PYTHON_PATH="$PROJECT_ROOT/venv/bin/python" + PYTHON_CMD=("$PROJECT_ROOT/venv/bin/python") elif [ -f "$PROJECT_ROOT/.venv/bin/python" ]; then - PYTHON_PATH="$PROJECT_ROOT/.venv/bin/python" + PYTHON_CMD=("$PROJECT_ROOT/.venv/bin/python") +elif command -v uv >/dev/null 2>&1; then + PYTHON_CMD=("uv" "run" "python") else - PYTHON_PATH="python3" + PYTHON_CMD=("python") fi # Run the proxy script -"$PYTHON_PATH" "$SCRIPT_DIR/simple_proxy.py" +"${PYTHON_CMD[@]}" "$SCRIPT_DIR/simple_proxy.py" diff --git a/superset/views/api.py b/superset/views/api.py index 736c0b55483c..4e59c3abf57a 100644 --- a/superset/views/api.py +++ b/superset/views/api.py @@ -16,13 +16,17 @@ # under the License. from __future__ import annotations +import logging from typing import Any, TYPE_CHECKING +from urllib.parse import urljoin -from flask import request +import requests +from flask import current_app, request from flask_appbuilder import expose from flask_appbuilder.api import rison from flask_appbuilder.security.decorators import has_access_api from flask_babel import lazy_gettext as _ +from requests import RequestException from superset import db, event_logger from superset.commands.chart.exceptions import ( @@ -34,12 +38,15 @@ from superset.superset_typing import FlaskResponse from superset.utils import json from superset.utils.date_parser import get_since_until +from superset.utils.core import sanitize_svg_content from superset.views.base import api, BaseSupersetView from superset.views.error_handling import handle_api_exception if TYPE_CHECKING: from superset.common.query_context_factory import QueryContextFactory +logger = logging.getLogger(__name__) + get_time_range_schema = { "type": ["string", "array"], "items": { @@ -126,6 +133,117 @@ def time_range(self, **kwargs: Any) -> FlaskResponse: error_msg = {"message": _("Unexpected time range: %(error)s", error=error)} return self.json_response(error_msg, 400) + @event_logger.log_this + @api + @handle_api_exception + @has_access_api + @expose("/v1/kroki/render/", methods=("POST",)) + def kroki_render(self) -> FlaskResponse: + """ + Render a diagram via a Kroki sidecar and return SVG only. 
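+
+        Expects a JSON body with ``diagram_type``, ``diagram_source``, and an
+        optional ``output_format`` (only ``svg`` is accepted). The rendered SVG
+        is sanitized and returned under ``result.svg``. Behavior is governed by
+        the ``KROKI_BASE_URL``, ``KROKI_REQUEST_TIMEOUT``,
+        ``KROKI_MAX_SOURCE_LENGTH``, and ``KROKI_ALLOWED_DIAGRAM_TYPES``
+        config keys.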
+ """ + payload = request.get_json(silent=True) or {} + + diagram_type = str(payload.get("diagram_type", "")).strip().lower() + diagram_source = payload.get("diagram_source") + output_format = str(payload.get("output_format", "svg")).strip().lower() + + if output_format != "svg": + return self.json_response( + {"message": _("Kroki output_format must be svg.")}, + 400, + ) + + if not diagram_type: + return self.json_response( + {"message": _("Missing required field: diagram_type.")}, + 400, + ) + + if not isinstance(diagram_source, str) or not diagram_source.strip(): + return self.json_response( + {"message": _("Missing required field: diagram_source.")}, + 400, + ) + + max_source_length = int(current_app.config.get("KROKI_MAX_SOURCE_LENGTH", 0)) + if max_source_length > 0 and len(diagram_source) > max_source_length: + return self.json_response( + { + "message": _( + "diagram_source exceeds max length of %(length)s characters.", + length=max_source_length, + ) + }, + 400, + ) + + allowed_diagram_types = { + diagram.lower() + for diagram in current_app.config.get("KROKI_ALLOWED_DIAGRAM_TYPES", []) + if isinstance(diagram, str) and diagram.strip() + } + if allowed_diagram_types and diagram_type not in allowed_diagram_types: + return self.json_response( + { + "message": _( + "Unsupported diagram_type: %(diagram_type)s", + diagram_type=diagram_type, + ) + }, + 400, + ) + + kroki_base_url = str(current_app.config.get("KROKI_BASE_URL", "")).strip() + if not kroki_base_url: + return self.json_response( + {"message": _("KROKI_BASE_URL is not configured.")}, + 500, + ) + + kroki_timeout = int(current_app.config.get("KROKI_REQUEST_TIMEOUT", 10)) + kroki_url = urljoin(kroki_base_url.rstrip("/") + "/", f"{diagram_type}/svg") + + try: + response = requests.post( + kroki_url, + data=diagram_source.encode("utf-8"), + headers={ + "Accept": "image/svg+xml", + "Content-Type": "text/plain; charset=utf-8", + }, + timeout=kroki_timeout, + ) + except RequestException as ex: + logger.warning("Kroki sidecar request failed: %s", ex) + return self.json_response( + {"message": _("Failed to reach Kroki renderer.")}, + 502, + ) + + if response.status_code >= 400: + logger.warning( + "Kroki sidecar returned status %s for diagram_type=%s", + response.status_code, + diagram_type, + ) + return self.json_response( + {"message": _("Kroki failed to render diagram.")}, + 502, + ) + + svg = sanitize_svg_content(response.text) + return self.json_response( + { + "result": { + "diagram_type": diagram_type, + "output_format": "svg", + "svg": svg, + } + }, + 200, + ) + def get_query_context_factory(self) -> QueryContextFactory: if self.query_context_factory is None: # pylint: disable=import-outside-toplevel diff --git a/tests/unit_tests/mcp_service/chart/test_chart_utils.py b/tests/unit_tests/mcp_service/chart/test_chart_utils.py index eab027c3dcac..f7d8b10b7b2f 100644 --- a/tests/unit_tests/mcp_service/chart/test_chart_utils.py +++ b/tests/unit_tests/mcp_service/chart/test_chart_utils.py @@ -495,6 +495,11 @@ def test_generate_explore_link_exception_handling(self, mock_get_base_url) -> No == "http://localhost:9001/explore/?datasource_type=table&datasource_id=123" ) + def test_generate_explore_link_virtual_dataset(self) -> None: + """Virtual datasets do not produce Explore URLs.""" + result = generate_explore_link("virtual:abc123", {"viz_type": "table"}) + assert result == "" + class TestCriticalBugFixes: """Test critical bug fixes for chart utilities.""" diff --git a/tests/unit_tests/mcp_service/chart/test_virtual_dataset_bridge.py 
b/tests/unit_tests/mcp_service/chart/test_virtual_dataset_bridge.py new file mode 100644 index 000000000000..444c8d7bfd75 --- /dev/null +++ b/tests/unit_tests/mcp_service/chart/test_virtual_dataset_bridge.py @@ -0,0 +1,62 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from superset.mcp_service.chart.virtual_dataset_bridge import ( + build_virtual_dataset_query, +) + + +def test_build_virtual_dataset_query_raw_mode() -> None: + """Raw mode query selects configured columns with row limit.""" + sql = build_virtual_dataset_query( + { + "query_mode": "raw", + "all_columns": ["region", "sales"], + }, + limit=50, + ) + assert 'SELECT "region", "sales" FROM data' in sql + assert "LIMIT 50" in sql + + +def test_build_virtual_dataset_query_aggregate_with_filters() -> None: + """Aggregate mode query applies dimensions, metrics, and filters.""" + sql = build_virtual_dataset_query( + { + "x_axis": "region", + "metrics": [ + { + "aggregate": "SUM", + "column": {"column_name": "sales"}, + "label": "Total Sales", + } + ], + "adhoc_filters": [ + { + "expressionType": "SIMPLE", + "subject": "status", + "operator": "==", + "comparator": "active", + } + ], + }, + limit=100, + ) + assert 'SELECT "region", SUM("sales") AS "Total Sales" FROM data' in sql + assert "WHERE \"status\" = 'active'" in sql + assert 'GROUP BY "region"' in sql + assert "LIMIT 100" in sql diff --git a/tests/unit_tests/mcp_service/chart/tool/test_generate_chart.py b/tests/unit_tests/mcp_service/chart/tool/test_generate_chart.py index 62bd9bbd22f1..892033c50020 100644 --- a/tests/unit_tests/mcp_service/chart/tool/test_generate_chart.py +++ b/tests/unit_tests/mcp_service/chart/tool/test_generate_chart.py @@ -174,6 +174,12 @@ async def test_dataset_id_flexibility(self): chart_type="table", columns=[ColumnRef(name="col1")] ), ), + GenerateChartRequest( + dataset_id="virtual:abc123", + config=TableChartConfig( + chart_type="table", columns=[ColumnRef(name="col1")] + ), + ), ] for config in configs: diff --git a/tests/unit_tests/mcp_service/chart/validation/test_dataset_validator.py b/tests/unit_tests/mcp_service/chart/validation/test_dataset_validator.py new file mode 100644 index 000000000000..105ed5914468 --- /dev/null +++ b/tests/unit_tests/mcp_service/chart/validation/test_dataset_validator.py @@ -0,0 +1,84 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pyarrow as pa + +from superset.mcp_service.chart.schemas import ColumnRef, TableChartConfig +from superset.mcp_service.chart.validation.dataset_validator import DatasetValidator + + +def test_get_dataset_context_from_virtual_dataset() -> None: + """Virtual dataset identifiers resolve to virtual dataset context.""" + mock_registry = MagicMock() + mock_registry.get.return_value = SimpleNamespace( + id="abc123", + name="virtual_sales", + schema=pa.schema( + [ + pa.field("event_time", pa.timestamp("ms")), + pa.field("value", pa.float64()), + ] + ), + ) + + with patch( + "superset.mcp_service.chart.validation.dataset_validator.get_registry", + return_value=mock_registry, + ): + context = DatasetValidator._get_dataset_context( + dataset_id="virtual:abc123", + session_id="session_1", + user_id=42, + ) + + assert context is not None + assert context.id == "abc123" + assert context.table_name == "virtual_sales" + assert len(context.available_columns) == 2 + + +def test_validate_against_virtual_dataset_columns() -> None: + """Column validation uses virtual dataset schema for chart config checks.""" + mock_registry = MagicMock() + mock_registry.get.return_value = SimpleNamespace( + id="abc123", + name="virtual_sales", + schema=pa.schema([pa.field("region", pa.string())]), + ) + + config = TableChartConfig( + chart_type="table", + columns=[ColumnRef(name="missing_column")], + ) + + with patch( + "superset.mcp_service.chart.validation.dataset_validator.get_registry", + return_value=mock_registry, + ): + is_valid, error = DatasetValidator.validate_against_dataset( + config, + dataset_id="virtual:abc123", + session_id="session_1", + user_id=42, + ) + + assert not is_valid + assert error is not None + assert error.error_type == "column_not_found" diff --git a/tests/unit_tests/mcp_service/dataframe/test_identifiers.py b/tests/unit_tests/mcp_service/dataframe/test_identifiers.py new file mode 100644 index 000000000000..ce44fafdc108 --- /dev/null +++ b/tests/unit_tests/mcp_service/dataframe/test_identifiers.py @@ -0,0 +1,39 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +from superset.mcp_service.dataframe.identifiers import ( + extract_virtual_dataset_id, + is_virtual_dataset_identifier, + normalize_virtual_dataset_id, + to_virtual_dataset_id, +) + + +def test_virtual_dataset_identifier_helpers() -> None: + """Virtual dataset identifier helpers normalize prefix handling.""" + assert is_virtual_dataset_identifier("virtual:abc123") + assert not is_virtual_dataset_identifier("abc123") + assert not is_virtual_dataset_identifier(123) + + assert normalize_virtual_dataset_id("virtual:abc123") == "abc123" + assert normalize_virtual_dataset_id("abc123") == "abc123" + + assert extract_virtual_dataset_id("virtual:abc123") == "abc123" + assert extract_virtual_dataset_id("abc123") is None + + assert to_virtual_dataset_id("abc123") == "virtual:abc123" + assert to_virtual_dataset_id("virtual:abc123") == "virtual:abc123" diff --git a/tests/unit_tests/mcp_service/dataframe/test_registry.py b/tests/unit_tests/mcp_service/dataframe/test_registry.py index c846f160bfb2..3b9a2f7086ea 100644 --- a/tests/unit_tests/mcp_service/dataframe/test_registry.py +++ b/tests/unit_tests/mcp_service/dataframe/test_registry.py @@ -199,10 +199,6 @@ def test_registry_list_datasets( # Listing without session_id or user_id should raise ValueError (security) with pytest.raises(ValueError, match="At least one of session_id or user_id"): registry.list_datasets() - - # List all datasets - all_datasets = registry.list_datasets() - assert len(all_datasets) == 3 # List datasets for session-1 session1_datasets = registry.list_datasets(session_id="session-1") diff --git a/tests/unit_tests/mcp_service/dataframe/test_schemas.py b/tests/unit_tests/mcp_service/dataframe/test_schemas.py index cc61ffd4d1ec..d7c4bf716d2b 100644 --- a/tests/unit_tests/mcp_service/dataframe/test_schemas.py +++ b/tests/unit_tests/mcp_service/dataframe/test_schemas.py @@ -19,15 +19,22 @@ from __future__ import annotations -from datetime import datetime, timezone +from datetime import datetime, timedelta, timezone import pytest from pydantic import ValidationError from superset.mcp_service.dataframe.schemas import ( ColumnSchema, + DataFrameSourceCapability, + DataFusionQueryResponse, + DataFusionQueryRequest, + DataFusionSourceConfig, IngestDataFrameRequest, IngestDataFrameResponse, + ListSourceCapabilitiesRequest, + ListSourceCapabilitiesResponse, + PrometheusQueryRequest, VirtualDatasetInfo, ) @@ -158,3 +165,117 @@ def test_ingest_dataframe_response_error() -> None: assert response.dataset is None assert response.error == "Dataset size exceeds limit" assert response.error_code == "SIZE_LIMIT_EXCEEDED" + + +def test_prometheus_query_request_defaults() -> None: + """Test PrometheusQueryRequest with minimal valid fields.""" + request = PrometheusQueryRequest( + base_url="http://prometheus:9090", + promql="up", + ) + assert request.query_type == "range" + assert request.step_seconds == 60 + assert request.ingest_as_virtual_dataset is True + + +def test_datafusion_source_config_validation() -> None: + """Test DataFusion source validation by source type.""" + parquet_source = DataFusionSourceConfig( + name="metrics", + source_type="parquet", + path="/tmp/data.parquet", + ) + assert parquet_source.path == "/tmp/data.parquet" + + with pytest.raises(ValidationError): + DataFusionSourceConfig(name="bad", source_type="parquet") + + with pytest.raises(ValidationError): + DataFusionSourceConfig(name="bad", source_type="arrow_ipc") + + with pytest.raises(ValidationError): + DataFusionSourceConfig(name="bad", source_type="virtual_dataset") + 
+ +def test_datafusion_query_request_validation() -> None: + """Test DataFusion query request schema validation.""" + request = DataFusionQueryRequest( + sql="SELECT * FROM metrics", + sources=[ + DataFusionSourceConfig( + name="metrics", + source_type="parquet", + path="/tmp/data.parquet", + ) + ], + ingest_result=True, + ) + assert request.limit == 1000 + assert request.ingest_result is True + + with pytest.raises(ValidationError): + DataFusionQueryRequest(sql="SELECT 1", sources=[]) + + +def test_prometheus_query_request_time_window_validation() -> None: + """Test range query time window validation.""" + now = datetime.now(timezone.utc) + with pytest.raises(ValidationError): + PrometheusQueryRequest( + base_url="http://prometheus:9090", + promql="up", + query_type="range", + start_time=now, + end_time=now - timedelta(minutes=5), + ) + + +def test_datafusion_query_response_source_capabilities() -> None: + """DataFusion responses accept source capability metadata.""" + response = DataFusionQueryResponse( + success=True, + rows=[{"value": 1}], + columns=[{"name": "value", "type": "int64"}], + row_count=1, + source_capabilities=[ + DataFrameSourceCapability( + source_type="parquet", + adapter_name="ParquetSourceAdapter", + supports_streaming=True, + supports_projection_pushdown=True, + supports_predicate_pushdown=True, + supports_sql_pushdown=True, + supports_virtual_dataset_ingestion=True, + ) + ], + ) + assert len(response.source_capabilities) == 1 + assert response.source_capabilities[0].source_type == "parquet" + + +def test_list_source_capabilities_schemas() -> None: + """Capability list request/response schemas validate correctly.""" + request = ListSourceCapabilitiesRequest( + source_types=["parquet", "virtual_dataset"], + include_prometheus=False, + ) + assert request.source_types == ["parquet", "virtual_dataset"] + assert request.include_prometheus is False + + response = ListSourceCapabilitiesResponse( + success=True, + capabilities=[ + DataFrameSourceCapability( + source_type="parquet", + adapter_name="ParquetSourceAdapter", + supports_streaming=True, + supports_projection_pushdown=True, + supports_predicate_pushdown=True, + supports_sql_pushdown=True, + supports_virtual_dataset_ingestion=True, + ) + ], + total_count=1, + ) + assert response.success is True + assert response.total_count == 1 diff --git a/tests/unit_tests/mcp_service/dataframe/tool/test_list_source_capabilities.py b/tests/unit_tests/mcp_service/dataframe/tool/test_list_source_capabilities.py new file mode 100644 index 000000000000..abf9fc16c657 --- /dev/null +++ b/tests/unit_tests/mcp_service/dataframe/tool/test_list_source_capabilities.py @@ -0,0 +1,82 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +from collections.abc import Generator +from unittest.mock import Mock, patch + +import pytest +from fastmcp import Client + +from superset.mcp_service.dataframe.schemas import ListSourceCapabilitiesRequest +from superset.mcp_service.app import mcp +from superset.utils import json + + +@pytest.fixture(autouse=True) +def mock_auth() -> Generator[None, None, None]: + """Mock authentication for all tests.""" + with patch("superset.mcp_service.auth.get_user_from_request") as mock_get_user: + mock_user = Mock() + mock_user.id = 1 + mock_user.username = "admin" + mock_get_user.return_value = mock_user + yield + + +@pytest.fixture +def mcp_server(): + return mcp + + +@pytest.mark.asyncio +async def test_list_source_capabilities_default(mcp_server) -> None: + """Default capability listing includes DataFusion and Prometheus sources.""" + async with Client(mcp_server) as client: + result = await client.call_tool( + "list_source_capabilities", + {"request": ListSourceCapabilitiesRequest().model_dump()}, + ) + + data = json.loads(result.content[0].text) + assert data["success"] is True + source_types = {cap["source_type"] for cap in data["capabilities"]} + assert { + "parquet", + "arrow_ipc", + "virtual_dataset", + "prometheus_http", + } <= source_types + + +@pytest.mark.asyncio +async def test_list_source_capabilities_filtered(mcp_server) -> None: + """Filtering source types and Prometheus flag scopes the capability output.""" + async with Client(mcp_server) as client: + result = await client.call_tool( + "list_source_capabilities", + { + "request": ListSourceCapabilitiesRequest( + source_types=["parquet"], + include_prometheus=False, + ).model_dump() + }, + ) + + data = json.loads(result.content[0].text) + assert data["success"] is True + assert data["total_count"] == 1 + assert data["capabilities"][0]["source_type"] == "parquet" diff --git a/tests/unit_tests/mcp_service/dataframe/tool/test_source_adapters.py b/tests/unit_tests/mcp_service/dataframe/tool/test_source_adapters.py new file mode 100644 index 000000000000..1b79c0622353 --- /dev/null +++ b/tests/unit_tests/mcp_service/dataframe/tool/test_source_adapters.py @@ -0,0 +1,80 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +from types import SimpleNamespace +from unittest.mock import MagicMock + +import pyarrow as pa + +from superset.mcp_service.dataframe.schemas import DataFusionSourceConfig +from superset.mcp_service.dataframe.tool.source_adapters import ( + get_datafusion_source_adapter, + list_datafusion_source_capabilities, +) + + +def test_get_datafusion_source_adapter() -> None: + """Known source adapters are discoverable by source_type.""" + assert get_datafusion_source_adapter("parquet") is not None + assert get_datafusion_source_adapter("arrow_ipc") is not None + assert get_datafusion_source_adapter("virtual_dataset") is not None + assert get_datafusion_source_adapter("unknown") is None + + +def test_list_datafusion_source_capabilities() -> None: + """Capability metadata is returned for supported source types.""" + capabilities = list_datafusion_source_capabilities() + assert {cap.source_type for cap in capabilities} == { + "parquet", + "arrow_ipc", + "virtual_dataset", + } + + +def test_virtual_dataset_adapter_normalizes_prefixed_id() -> None: + """Virtual adapter accepts prefixed IDs and resolves raw registry IDs.""" + adapter = get_datafusion_source_adapter("virtual_dataset") + assert adapter is not None + + source = DataFusionSourceConfig( + name="source_table", + source_type="virtual_dataset", + dataset_id="virtual:abc123", + ) + registry = MagicMock() + registry.get.return_value = SimpleNamespace( + table=pa.table({"value": [1, 2, 3]}), + ) + + class SessionCtx: + def __init__(self) -> None: + self.calls: list[str] = [] + + def register_record_batches(self, table_name: str, batches: object) -> None: + self.calls.append(table_name) + + session_ctx = SessionCtx() + adapter.register_source( + session_ctx=session_ctx, + source=source, + registry=registry, + session_id="session_1", + user_id=42, + ) + + registry.get.assert_called_once_with("abc123", session_id="session_1", user_id=42) + assert session_ctx.calls == ["source_table"] diff --git a/tests/unit_tests/views/test_kroki_api.py b/tests/unit_tests/views/test_kroki_api.py new file mode 100644 index 000000000000..f9003f2d67d3 --- /dev/null +++ b/tests/unit_tests/views/test_kroki_api.py @@ -0,0 +1,102 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+from unittest.mock import Mock, patch
+
+import pytest
+from requests import RequestException
+
+
+def test_kroki_render_success(client, full_api_access) -> None:
+    with patch("superset.views.api.requests.post") as post_mock:
+        post_mock.return_value = Mock(
+            status_code=200,
+            text='<svg xmlns="http://www.w3.org/2000/svg"></svg>',
+        )
+
+        response = client.post(
+            "/api/v1/kroki/render/",
+            json={
+                "diagram_type": "mermaid",
+                "diagram_source": "graph TD; A-->B;",
+                "output_format": "svg",
+            },
+        )
+
+        assert response.status_code == 200
+        payload = response.get_json()
+        assert payload is not None
+        assert payload["result"]["diagram_type"] == "mermaid"
+        assert payload["result"]["output_format"] == "svg"
+        assert "<svg" in payload["result"]["svg"]
+
+
+def test_kroki_render_rejects_non_svg_output(client, full_api_access) -> None:
+    response = client.post(
+        "/api/v1/kroki/render/",
+        json={
+            "diagram_type": "mermaid",
+            "diagram_source": "graph TD; A-->B;",
+            "output_format": "png",
+        },
+    )
+
+    assert response.status_code == 400
+    payload = response.get_json()
+    assert payload is not None
+    assert "output_format" in payload["message"]
+
+
+@pytest.mark.parametrize(
+    "diagram_type",
+    [
+        "unsupported_type",
+        "",
+    ],
+)
+def test_kroki_render_rejects_invalid_diagram_type(
+    client, full_api_access, diagram_type: str
+) -> None:
+    response = client.post(
+        "/api/v1/kroki/render/",
+        json={
+            "diagram_type": diagram_type,
+            "diagram_source": "graph TD; A-->B;",
+            "output_format": "svg",
+        },
+    )
+
+    assert response.status_code == 400
+
+
+def test_kroki_render_handles_sidecar_failure(client, full_api_access) -> None:
+    with patch("superset.views.api.requests.post") as post_mock:
+        post_mock.side_effect = RequestException("connection refused")
+
+        response = client.post(
+            "/api/v1/kroki/render/",
+            json={
+                "diagram_type": "mermaid",
+                "diagram_source": "graph TD; A-->B;",
+                "output_format": "svg",
+            },
+        )
+
+    assert response.status_code == 502
+    payload = response.get_json()
+    assert payload is not None
+    assert "Kroki" in payload["message"]