# Reconstructed from a whitespace-collapsed unified diff that touches two files:
#   * sqlmesh/core/engine_adapter/redshift.py    -> adds `_build_table_properties_exp`
#   * tests/core/engine_adapter/test_redshift.py -> adds the tests further below
#
# Imports the patch adds next to these definitions:
#   redshift.py:      from sqlglot.helper import ensure_list
#                     from sqlmesh.core.node import IntervalUnit
#                     (placed beside the other typing-only imports; presumably
#                     under a TYPE_CHECKING guard -- confirm against the file)
#   test_redshift.py: import sqlmesh.core.dialect as d
#                     from sqlmesh.core.model import load_sql_based_model
#                     from sqlmesh.core.model.definition import SqlModel


# --- sqlmesh/core/engine_adapter/redshift.py ---------------------------------
# NOTE(review): this is a method; its class header lies outside the patch hunk
# (presumably RedshiftEngineAdapter). Indent it to class level when applying.
def _build_table_properties_exp(
    self,
    catalog_name: t.Optional[str] = None,
    table_format: t.Optional[str] = None,
    storage_format: t.Optional[str] = None,
    partitioned_by: t.Optional[t.List[exp.Expr]] = None,
    partition_interval_unit: t.Optional[IntervalUnit] = None,
    clustered_by: t.Optional[t.List[exp.Expr]] = None,
    table_properties: t.Optional[t.Dict[str, exp.Expr]] = None,
    target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None,
    table_description: t.Optional[str] = None,
    table_kind: t.Optional[str] = None,
    **kwargs: t.Any,
) -> t.Optional[exp.Properties]:
    """Build the table-level properties used when creating a table.

    Only ``table_description`` and ``table_properties`` are read here; the
    remaining parameters are accepted but not used by this implementation.

    ``table_properties`` keys are matched case-insensitively:

    * the creatable type, via ``_pop_creatable_type_from_properties``
    * ``DISTSTYLE`` -> ``exp.DistStyleProperty`` (the value's name, upper-cased)
    * ``DISTKEY``   -> ``exp.DistKeyProperty``
    * ``SORTKEY``   -> non-compound ``exp.SortKeyProperty``; a value carrying
      ``expressions`` (e.g. a tuple) produces a multi-column sort key

    String literals given for DISTKEY / SORTKEY are turned into identifiers.
    Redundant wrapping parentheses, e.g. ``sortkey = (batch_time)``, are
    stripped so that they do not leak into the generated DDL.

    Returns:
        An ``exp.Properties`` node holding the collected properties, or
        ``None`` when there is nothing to emit.
    """
    properties: t.List[exp.Expr] = []

    if table_description:
        comment = exp.Literal.string(self._truncate_table_comment(table_description))
        properties.append(exp.SchemaCommentProperty(this=comment))

    def _to_identifier_if_string(expression: exp.Expr) -> exp.Expr:
        # A single parenthesised value is an exp.Paren; unwrap it so we emit
        # KEY("col") rather than KEY(("col")).
        expression = expression.unnest()
        if isinstance(expression, exp.Literal) and expression.is_string:
            return exp.to_identifier(expression.this)
        # Copy so the caller's expression tree is not re-parented.
        return expression.copy()

    if table_properties:
        # Upper-casing the keys also gives us a fresh dict, so the helper
        # below works on a copy rather than on the caller's mapping.
        table_properties = {k.upper(): v for k, v in table_properties.items()}

        table_type = self._pop_creatable_type_from_properties(table_properties)
        properties.extend(ensure_list(table_type))

        diststyle = table_properties.get("DISTSTYLE")
        if diststyle:
            # `.name` of a Paren is empty, hence the unnest() first.
            style_name = diststyle.unnest().name.upper()
            properties.append(exp.DistStyleProperty(this=exp.var(style_name)))

        distkey = table_properties.get("DISTKEY")
        if distkey:
            properties.append(exp.DistKeyProperty(this=_to_identifier_if_string(distkey)))

        sortkey = table_properties.get("SORTKEY")
        if sortkey:
            sortkey = sortkey.unnest()
            # Tuple-like values expose their columns through `.expressions`;
            # anything else is treated as a single sort column.
            sortkey_expressions = sortkey.expressions or [sortkey]
            properties.append(
                exp.SortKeyProperty(
                    this=[
                        _to_identifier_if_string(expression)
                        for expression in sortkey_expressions
                    ],
                    compound=False,
                )
            )

    return exp.Properties(expressions=properties) if properties else None


# --- tests/core/engine_adapter/test_redshift.py ------------------------------
def test_create_table_physical_properties(make_mocked_engine_adapter: t.Callable):
    """DISTSTYLE, DISTKEY and SORTKEY given as columns are rendered in the DDL."""
    adapter = make_mocked_engine_adapter(RedshiftEngineAdapter)

    adapter.create_table(
        "test_schema.test_table",
        {
            "id_file": exp.DataType.build("INT"),
            "batch_time": exp.DataType.build("TIMESTAMP"),
        },
        table_properties={
            "diststyle": exp.column("key"),
            "distkey": exp.to_column("id_file"),
            "sortkey": exp.to_column("batch_time"),
        },
    )

    assert to_sql_calls(adapter) == [
        'CREATE TABLE IF NOT EXISTS "test_schema"."test_table" ("id_file" INTEGER, "batch_time" TIMESTAMP) DISTSTYLE KEY DISTKEY("id_file") SORTKEY("batch_time")',
    ]


@pytest.mark.parametrize(
    ("diststyle", "expected"),
    [
        ("auto", "AUTO"),
        ("even", "EVEN"),
        ("key", "KEY"),
        ("all", "ALL"),
    ],
)
def test_create_table_physical_properties_diststyle_values(
    make_mocked_engine_adapter: t.Callable,
    diststyle: str,
    expected: str,
):
    """Each DISTSTYLE value is upper-cased; DISTKEY is only supplied for KEY."""
    adapter = make_mocked_engine_adapter(RedshiftEngineAdapter)
    table_properties = {"diststyle": exp.column(diststyle)}
    if diststyle == "key":
        table_properties["distkey"] = exp.to_column("id_file")

    adapter.create_table(
        "test_schema.test_table",
        {"id_file": exp.DataType.build("INT")},
        table_properties=table_properties,
    )

    expected_distkey = ' DISTKEY("id_file")' if diststyle == "key" else ""
    assert to_sql_calls(adapter) == [
        f'CREATE TABLE IF NOT EXISTS "test_schema"."test_table" ("id_file" INTEGER) DISTSTYLE {expected}{expected_distkey}',
    ]


def test_create_table_physical_properties_distkey_without_diststyle(
    make_mocked_engine_adapter: t.Callable,
):
    """A DISTKEY on its own is emitted without any DISTSTYLE clause."""
    adapter = make_mocked_engine_adapter(RedshiftEngineAdapter)

    adapter.create_table(
        "test_schema.test_table",
        {"id_file": exp.DataType.build("INT")},
        table_properties={"distkey": exp.to_column("id_file")},
    )

    assert to_sql_calls(adapter) == [
        'CREATE TABLE IF NOT EXISTS "test_schema"."test_table" ("id_file" INTEGER) DISTKEY("id_file")',
    ]


def test_create_table_physical_properties_multi_column_sortkey(
    make_mocked_engine_adapter: t.Callable,
):
    """A tuple SORTKEY expands to one entry per column, in order."""
    adapter = make_mocked_engine_adapter(RedshiftEngineAdapter)

    adapter.create_table(
        "test_schema.test_table",
        {
            "id_file": exp.DataType.build("INT"),
            "batch_time": exp.DataType.build("TIMESTAMP"),
            "event_time": exp.DataType.build("TIMESTAMP"),
        },
        table_properties={
            "diststyle": exp.column("key"),
            "distkey": exp.to_column("id_file"),
            "sortkey": exp.Tuple(
                expressions=[exp.to_column("batch_time"), exp.to_column("event_time")]
            ),
        },
    )

    assert to_sql_calls(adapter) == [
        'CREATE TABLE IF NOT EXISTS "test_schema"."test_table" ("id_file" INTEGER, "batch_time" TIMESTAMP, "event_time" TIMESTAMP) DISTSTYLE KEY DISTKEY("id_file") SORTKEY("batch_time", "event_time")',
    ]


def test_create_table_physical_properties_with_string_columns(
    make_mocked_engine_adapter: t.Callable,
):
    """String literals are accepted and rendered as identifiers."""
    adapter = make_mocked_engine_adapter(RedshiftEngineAdapter)

    adapter.create_table(
        "test_schema.test_table",
        {
            "id_file": exp.DataType.build("INT"),
            "batch_time": exp.DataType.build("TIMESTAMP"),
        },
        table_properties={
            "diststyle": exp.Literal.string("key"),
            "distkey": exp.Literal.string("id_file"),
            "sortkey": exp.Literal.string("batch_time"),
        },
    )

    assert to_sql_calls(adapter) == [
        'CREATE TABLE IF NOT EXISTS "test_schema"."test_table" ("id_file" INTEGER, "batch_time" TIMESTAMP) DISTSTYLE KEY DISTKEY("id_file") SORTKEY("batch_time")',
    ]


def test_create_table_physical_properties_parenthesized_values(
    make_mocked_engine_adapter: t.Callable,
):
    """Redundant parentheses around single values do not leak into the DDL."""
    adapter = make_mocked_engine_adapter(RedshiftEngineAdapter)

    adapter.create_table(
        "test_schema.test_table",
        {
            "id_file": exp.DataType.build("INT"),
            "batch_time": exp.DataType.build("TIMESTAMP"),
        },
        table_properties={
            "diststyle": exp.Paren(this=exp.column("key")),
            "distkey": exp.Paren(this=exp.to_column("id_file")),
            "sortkey": exp.Paren(this=exp.to_column("batch_time")),
        },
    )

    assert to_sql_calls(adapter) == [
        'CREATE TABLE IF NOT EXISTS "test_schema"."test_table" ("id_file" INTEGER, "batch_time" TIMESTAMP) DISTSTYLE KEY DISTKEY("id_file") SORTKEY("batch_time")',
    ]


def test_create_table_physical_properties_from_model_definition(
    make_mocked_engine_adapter: t.Callable,
):
    """Properties parsed from a MODEL block's physical_properties reach the DDL."""
    adapter = make_mocked_engine_adapter(RedshiftEngineAdapter)
    model: SqlModel = t.cast(
        SqlModel,
        load_sql_based_model(
            d.parse(
                """
MODEL (
    name test_schema.test_table,
    kind full,
    physical_properties (
        diststyle = key,
        distkey = "id_file",
        sortkey = "batch_time"
    )
);
SELECT id_file::INT, batch_time::TIMESTAMP;
                """
            )
        ),
    )

    adapter.create_table(
        model.name,
        target_columns_to_types=model.columns_to_types_or_raise,
        table_properties=model.physical_properties,
    )

    assert to_sql_calls(adapter) == [
        'CREATE TABLE IF NOT EXISTS "test_schema"."test_table" ("id_file" INTEGER, "batch_time" TIMESTAMP) DISTSTYLE KEY DISTKEY("id_file") SORTKEY("batch_time")',
    ]