Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
COUNT,
DISTINCT,
LITERAL,
RANDOM,
REGEX_LIKE,
STRING_HASH,
TUPLE,
Expand Down Expand Up @@ -243,3 +244,6 @@ def _build_concat_ws_sql(self, concat_ws: CONCAT_WS) -> str:

def _build_string_hash_sql(self, string_hash: STRING_HASH) -> str:
    """Render a STRING_HASH expression, wrapping the base dialect's hash output in to_hex()."""
    base_hash_sql = super()._build_string_hash_sql(string_hash)
    return f"to_hex({base_hash_sql})"

def _build_random_sql(self, random: RANDOM) -> str:
    """Render the RANDOM AST node as this dialect's native random function."""
    # Per the RANDOM node's contract, this yields a float in [0.0, 1.0).
    return "RANDOM()"
5 changes: 5 additions & 0 deletions soda-core/src/soda_core/common/sql_ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,11 @@ def __post_init__(self):
self.handle_parent_node_update(self.expression)


@dataclass
class RANDOM(SqlExpression):
    """SQL AST node for a random-number expression.

    Generates a random number in the range [0.0, 1.0). The concrete SQL
    spelling is dialect-specific (e.g. ``RANDOM()`` vs ``RAND()``) and is
    produced by each dialect's ``_build_random_sql``.
    """


@dataclass
class SqlExpressionStr(SqlExpression):
expression_str: str
Expand Down
37 changes: 28 additions & 9 deletions soda-core/src/soda_core/common/sql_dialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@
ORDER_BY_ASC,
ORDER_BY_DESC,
ORDINAL_POSITION,
RANDOM,
RAW_SQL,
REGEX_LIKE,
SELECT,
Expand Down Expand Up @@ -778,6 +779,8 @@ def build_expression_sql(self, expression: SqlExpression | str | Number) -> str:
return self._build_star_sql(expression)
elif isinstance(expression, EXISTS):
return self._build_exists_sql(expression)
elif isinstance(expression, RANDOM):
return self._build_random_sql(expression)
raise Exception(f"Invalid expression type {expression.__class__.__name__}")

def _build_column_sql(self, column: COLUMN) -> str:
Expand Down Expand Up @@ -857,14 +860,15 @@ def _alias_format(self, alias: str) -> str:
def _build_from_part(self, from_part: FROM) -> str:
# "fully".qualified"."tablename" [AS "table_alias"]

from_parts: list[str] = [
self._build_qualified_quoted_dataset_name(
dataset_name=from_part.table_name, dataset_prefix=from_part.table_prefix
from_parts: list[str] = []
if from_part.sampler_type is not None and from_part.sample_size is not None:
from_parts.append(self._build_sample_sql(from_part))
else:
from_parts.append(
self._build_qualified_quoted_dataset_name(
dataset_name=from_part.table_name, dataset_prefix=from_part.table_prefix
)
)
]

if from_part.sampler_type is not None and isinstance(from_part.sample_size, Number):
from_parts.append(self._build_sample_sql(from_part.sampler_type, from_part.sample_size))

if isinstance(from_part.alias, str):
from_parts.append(self._alias_format(from_part.alias))
Expand Down Expand Up @@ -1146,8 +1150,23 @@ def format_expr(e: SqlExpression) -> SqlExpression:
string_to_hash = CONCAT_WS(separator="'||'", expressions=formatted_expressions)
return self.build_expression_sql(STRING_HASH(string_to_hash))

def _build_sample_sql(self, from_: FROM) -> str:
    """Build a SQL fragment that samples rows from the table referenced by ``from_``.

    Default (generic) strategy: a parenthesized subquery that selects all
    columns, orders randomly, and limits to the requested sample size, so the
    result can stand in for a table name inside a FROM clause.

    Raises:
        InvalidArgumentException: if the sampler type is not supported by
            this dialect.
    """
    # NOTE: the earlier NotImplementedError stub with the old
    # (sampler_type, sample_size) signature was superseded by this
    # implementation and has been removed as dead code.
    if from_.sampler_type is SamplerType.ABSOLUTE_LIMIT:
        sql = self.build_select_sql(
            [
                SELECT(STAR()),
                FROM(from_.table_name, from_.table_prefix),
                ORDER_BY_ASC(RANDOM()),
                LIMIT(from_.sample_size),
            ],
            # No trailing semicolon: this SQL is embedded as a subquery.
            add_semicolon=False,
        )
        return f"({sql})"
    else:
        raise InvalidArgumentException(f"Unsupported sampler type for this dialect: {from_.sampler_type}")

def _build_random_sql(self, random: RANDOM) -> str:
    """Default rendering of the RANDOM AST node; dialects override when they spell it differently (e.g. RAND())."""
    return "RANDOM()"

def information_schema_namespace_elements(self, data_source_namespace: DataSourceNamespace) -> list[str]:
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
ALTER_TABLE_ADD_COLUMN,
ALTER_TABLE_DROP_COLUMN,
CREATE_TABLE_COLUMN,
RANDOM,
)
from soda_core.common.sql_dialect import SqlDialect
from soda_core.common.statements.metadata_tables_query import MetadataTablesQuery
Expand Down Expand Up @@ -321,6 +322,9 @@ def convert_table_type_to_enum(self, table_type: str) -> TableType:
def metadata_casify(self, identifier: str) -> str:
    """Normalize a metadata identifier to this dialect's canonical case (lowercase)."""
    lowered = identifier.lower()
    return lowered

def _build_random_sql(self, random: RANDOM) -> str:
    """This dialect spells the SQL random function RAND() rather than RANDOM()."""
    return "RAND()"


class DatabricksHiveSqlDialect(DatabricksSqlDialect):
def post_schema_create_sql(self, prefixes: list[str]) -> Optional[list[str]]:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
from logging import Logger
from numbers import Number
from typing import TYPE_CHECKING, Optional

from soda_core.common.data_source_connection import DataSourceConnection
from soda_core.common.data_source_impl import DataSourceImpl
from soda_core.common.logging_constants import soda_logger
from soda_core.common.metadata_types import SamplerType, SodaDataTypeName
from soda_core.common.sql_ast import COLUMN, COUNT, DISTINCT, TUPLE, VALUES
from soda_core.common.sql_ast import COLUMN, COUNT, DISTINCT, FROM, TUPLE, VALUES
from soda_core.common.sql_dialect import SqlDialect
from soda_core.contracts.impl.contract_verification_impl import ContractImpl
from soda_snowflake.common.data_sources.snowflake_data_source_connection import (
Expand Down Expand Up @@ -192,8 +191,22 @@ def data_type_has_parameter_datetime_precision(self, data_type_name) -> bool:
TIMESTAMP_WITH_LOCAL_TIME_ZONE,
]

def _build_sample_sql(self, sampler_type: SamplerType, sample_size: Number) -> str:
if sampler_type is SamplerType.ABSOLUTE_LIMIT:
return f"TABLESAMPLE ({int(sample_size)} ROWS)"
def _build_from_part(self, from_part: FROM) -> str:
    """Build the FROM-clause fragment.

    Snowflake places the sampling clause after the alias, hence this
    override of the base dialect's ordering.
    """
    qualified_name = self._build_qualified_quoted_dataset_name(
        dataset_name=from_part.table_name, dataset_prefix=from_part.table_prefix
    )
    parts: list[str] = [qualified_name]
    if from_part.alias is not None:
        parts.append(self._alias_format(from_part.alias))
    if from_part.sampler_type is not None and from_part.sample_size is not None:
        parts.append(self._build_sample_sql(from_part))
    return " ".join(parts)

def _build_sample_sql(self, from_: FROM) -> str:
    """Render Snowflake's row-count sampling clause for the given FROM part.

    Raises:
        ValueError: if the sampler type is not ABSOLUTE_LIMIT.
    """
    # Fix: removed a stale duplicate raise line that referenced the
    # undefined name `sampler_type` (leftover from the old signature),
    # which would have been a NameError if reached.
    if from_.sampler_type is SamplerType.ABSOLUTE_LIMIT:
        # Snowflake supports sampling an absolute number of rows natively.
        return f"TABLESAMPLE ({int(from_.sample_size)} ROWS)"
    else:
        raise ValueError(f"Unsupported sample type: {from_.sampler_type}")
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
LIMIT,
OFFSET,
ORDER_BY_ASC,
RANDOM,
REGEX_LIKE,
SELECT,
STAR,
Expand Down Expand Up @@ -425,3 +426,6 @@ def build_create_view_sql(
# Drop the first prefix (database name) from the fully qualified view name
create_view_copy.fully_qualified_view_name = ".".join(create_view_copy.fully_qualified_view_name.split(".")[1:])
return super().build_create_view_sql(create_view_copy, add_semicolon, add_parenthesis=False)

def _build_random_sql(self, random: RANDOM) -> str:
return "RAND()"
30 changes: 19 additions & 11 deletions soda-tests/tests/features/test_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from soda_core.common.soda_cloud_dto import DatasetConfigurationDTO
from soda_core.common.soda_cloud_dto import SamplerType as SamplerTypeDTO
from soda_core.common.soda_cloud_dto import TestRowSamplerConfigurationDTO
from soda_core.common.sql_ast import FROM
from soda_core.contracts.contract_verification import ContractVerificationResult

test_table_specification = (
Expand Down Expand Up @@ -233,11 +234,19 @@ def test_sampling_custom_sql_pass(
table_full_name = test_table.qualified_name
sample_size = 3

def add_sample(name: str) -> str:
return f"{name} {data_source_test_helper.data_source_impl.sql_dialect._build_sample_sql(SamplerType.ABSOLUTE_LIMIT, sample_size)}"
def from_with_sampling(alias: str) -> str:
return data_source_test_helper.data_source_impl.sql_dialect._build_from_part(
FROM(
test_table.unique_name,
test_table.dataset_prefix,
alias=alias,
sampler_type=SamplerType.ABSOLUTE_LIMIT,
sample_size=sample_size,
)
)

def build_name_with_alias(name: str, alias: str) -> str:
return f"{name} AS {alias}"
def table_with_alias(alias: str) -> str:
return f"{table_full_name} AS {alias}"

data_source_test_helper.soda_cloud.set_dataset_configuration_response(
dataset_identifier=test_table.dataset_identifier,
Expand All @@ -254,12 +263,12 @@ def build_name_with_alias(name: str, alias: str) -> str:
checks:
- metric:
query: |
select avg({age_quoted}) from {build_name_with_alias(table_full_name, "metric_query")} where 1 = 1
select avg({age_quoted}) from {table_with_alias("metric_query")} where 1 = 1
threshold:
must_be_greater_than: 0
- failed_rows:
query: |
select * from {build_name_with_alias(table_full_name, "fr_query")} where {age_quoted} > 100
select * from {table_with_alias("fr_query")} where {age_quoted} > 100
""",
)

Expand All @@ -269,16 +278,15 @@ def build_name_with_alias(name: str, alias: str) -> str:
logs = contract_verification_result.get_logs_str().lower()

assert (
f"from {build_name_with_alias(table_full_name, 'metric_query')} where 1 = 1".lower() not in logs
f"from {table_with_alias('metric_query')} where 1 = 1".lower() not in logs
), "Original metric query should not be in logs"
assert (
f"from {add_sample(build_name_with_alias(table_full_name, 'metric_query'))}\nwhere\n 1 = 1".lower() in logs
f"from {from_with_sampling('metric_query')}\nwhere\n 1 = 1".lower() in logs
), "Sampled metric query should be in logs"

assert (
f"from {build_name_with_alias(table_full_name, 'fr_query')} where {age_quoted} > 100".lower() not in logs
f"from {table_with_alias('fr_query')} where {age_quoted} > 100".lower() not in logs
), "Original failed_rows query should not be in logs"
assert (
f"from {add_sample(build_name_with_alias(table_full_name, 'fr_query'))}\n where\n {age_quoted} > 100".lower()
in logs
f"from {from_with_sampling('fr_query')}\n where\n {age_quoted} > 100".lower() in logs
), "Sampled failed_rows query should be in logs"
Loading