diff --git a/requirements.txt b/requirements.txt
index 280274771..cbdf05784 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -8,7 +8,7 @@
 ./soda/spark[hive]
 ./soda/spark[odbc]
 ./soda/spark[databricks]
-./soda/spark_df
+./soda/spark_df[pyspark]
 ./soda/scientific[simulator]
 ./soda/sqlserver
 ./soda/mysql
diff --git a/soda/spark_df/setup.py b/soda/spark_df/setup.py
index 9c60d5139..6d4305e2b 100644
--- a/soda/spark_df/setup.py
+++ b/soda/spark_df/setup.py
@@ -8,12 +8,14 @@
 requires = [
     f"soda-core-spark=={package_version}",
-    "pyspark>=3.4.0",
 ]
 
+extras = {"pyspark": ["pyspark>=3.4.0"]}
+
 # TODO Fix the params
 setup(
     name=package_name,
     version=package_version,
     install_requires=requires,
+    extras_require=extras,
     packages=find_namespace_packages(include=["soda*"]),
 )
diff --git a/soda/spark_df/soda/data_sources/spark_df_connection.py b/soda/spark_df/soda/data_sources/spark_df_connection.py
index 752fe0a1e..2811a11a4 100644
--- a/soda/spark_df/soda/data_sources/spark_df_connection.py
+++ b/soda/spark_df/soda/data_sources/spark_df_connection.py
@@ -1,9 +1,13 @@
-from pyspark.sql.session import SparkSession
+from typing import TYPE_CHECKING
+
 from soda.data_sources.spark_df_cursor import SparkDfCursor
 
+if TYPE_CHECKING:
+    from pyspark.sql.session import SparkSession
+
 
 class SparkDfConnection:
-    def __init__(self, spark_session: SparkSession):
+    def __init__(self, spark_session: "SparkSession"):
         self.spark_session = spark_session
 
     def cursor(self) -> SparkDfCursor:
diff --git a/soda/spark_df/soda/data_sources/spark_df_contract_data_source.py b/soda/spark_df/soda/data_sources/spark_df_contract_data_source.py
index c4322fd8b..57d4fc0a4 100644
--- a/soda/spark_df/soda/data_sources/spark_df_contract_data_source.py
+++ b/soda/spark_df/soda/data_sources/spark_df_contract_data_source.py
@@ -2,8 +2,8 @@
 
 import logging
 import re
+from typing import TYPE_CHECKING
 
-from pyspark.sql import SparkSession
 from soda.data_sources.spark_df_connection import SparkDfConnection
 from soda.execution.data_type import DataType
 
@@ -13,9 +13,11 @@
 
 logger = logging.getLogger(__name__)
 
+if TYPE_CHECKING:
+    from pyspark.sql import SparkSession
 
 
-class SparkDfSqlDialect(SqlDialect):
+class SparkDfSqlDialect(SqlDialect):
     def __init__(self):
         super().__init__()
 
@@ -87,7 +89,6 @@ def regex_replace_flags(self) -> str:
 
 
 class SparkDfContractDataSource(FileClContractDataSource):
-
     def __init__(self, data_source_yaml_file: YamlFile, spark_session: SparkSession):
         data_source_yaml_dict: dict = data_source_yaml_file.get_dict()
         data_source_yaml_dict[self._KEY_TYPE] = "spark_df"
diff --git a/soda/spark_df/soda/data_sources/spark_df_cursor.py b/soda/spark_df/soda/data_sources/spark_df_cursor.py
index a73a8e67a..d16997561 100644
--- a/soda/spark_df/soda/data_sources/spark_df_cursor.py
+++ b/soda/spark_df/soda/data_sources/spark_df_cursor.py
@@ -1,7 +1,10 @@
 from __future__ import annotations
 
-from pyspark.sql import DataFrame, SparkSession
-from pyspark.sql.types import Row
+from typing import TYPE_CHECKING
+
+if TYPE_CHECKING:
+    from pyspark.sql import DataFrame, SparkSession
+    from pyspark.sql.types import Row
 
 
 class SparkDfCursor:
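
For context on the pattern this diff applies (not part of the change itself): pyspark moves from a hard install requirement of soda/spark_df to an optional "pyspark" extra, and the touched modules now import pyspark only under typing.TYPE_CHECKING, so the package can be imported in environments that already ship their own Spark runtime. Below is a minimal sketch of that import pattern, assuming nothing beyond the standard library and pyspark; the class name ExampleConnection is hypothetical and only illustrates the technique, it is not part of the diff.

    # Minimal sketch of the TYPE_CHECKING import pattern used in the diff above.
    # "ExampleConnection" is a hypothetical name for illustration only.
    from typing import TYPE_CHECKING

    if TYPE_CHECKING:
        # Evaluated only by static type checkers (mypy, IDEs), never at runtime,
        # so importing this module does not require pyspark to be installed.
        from pyspark.sql import SparkSession


    class ExampleConnection:
        def __init__(self, spark_session: "SparkSession"):
            # The quoted annotation keeps the reference lazy at runtime;
            # pyspark is only needed by callers that actually pass a SparkSession.
            self.spark_session = spark_session

With the setup.py change, installing the package with the extra, as requirements.txt now does via ./soda/spark_df[pyspark], pulls in pyspark>=3.4.0, while a plain install leaves the Spark runtime to the environment.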