diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index c64efa15e..ce5971d21 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -67,4 +67,4 @@ repos: language: node pass_filenames: false types: [python] - additional_dependencies: ['pyright@1.1.291'] + additional_dependencies: ['pyright@1.1.398'] diff --git a/lale/datasets/data_schemas.py b/lale/datasets/data_schemas.py index b0a58dde1..afcc6f7ae 100644 --- a/lale/datasets/data_schemas.py +++ b/lale/datasets/data_schemas.py @@ -315,8 +315,8 @@ def get_index_name(obj): return result -def get_index_names(obj): - result = None +def get_index_names(obj) -> Optional[list[str]]: + result: Optional[list[str]] = None if SparkDataFrame is not None and isinstance(obj, SparkDataFrameWithIndex): result = obj.index_names elif isinstance( diff --git a/lale/datasets/util.py b/lale/datasets/util.py index ff6877d68..d9ee9fae9 100644 --- a/lale/datasets/util.py +++ b/lale/datasets/util.py @@ -19,26 +19,27 @@ import pandas as pd from sklearn.utils import Bunch -from lale.datasets.data_schemas import add_table_name, get_table_name - -download_data_cache_dir: pathlib.Path = pathlib.Path( - os.environ.get("LALE_DOWNLOAD_CACHE_DIR", os.path.dirname(__file__)) +from lale.datasets.data_schemas import ( + SparkDataFrameWithIndex, + add_table_name, + get_table_name, ) try: from pyspark.sql import SparkSession - from lale.datasets.data_schemas import ( # pylint:disable=ungrouped-imports - SparkDataFrameWithIndex, - ) - - spark_installed = True except ImportError: - spark_installed = False + SparkSession = None + + +download_data_cache_dir: pathlib.Path = pathlib.Path( + os.environ.get("LALE_DOWNLOAD_CACHE_DIR", os.path.dirname(__file__)) +) def pandas2spark(pandas_df): - assert spark_installed + assert SparkSession is not None + spark_session = ( SparkSession.builder.master("local[2]") # type: ignore .config("spark.driver.memory", "64g") diff --git a/lale/expressions.py b/lale/expressions.py index 29c1dbc73..40e6a4917 100644 --- a/lale/expressions.py +++ b/lale/expressions.py @@ -12,6 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+# Due to changes in the ast package, the inferred types are confused, but the code still works. +# Changing the types presents compatibility challenges, so some pyright errors are disabled for now. +# pyright: reportArgumentType=false +# pyright: reportIncompatibleMethodOverride=false + + import ast  # see also https://greentreesnakes.readthedocs.io/ import pprint import typing diff --git a/lale/helpers.py b/lale/helpers.py index e0f817a24..a4eb7b206 100644 --- a/lale/helpers.py +++ b/lale/helpers.py @@ -55,6 +55,8 @@ spark_installed = spark_loader is not None if spark_installed: from pyspark.sql.dataframe import DataFrame as spark_df +else: +    spark_df = None  logger = logging.getLogger(__name__) @@ -1252,10 +1254,7 @@ def _is_spark_df(df):   def _is_spark_df_without_index(df): -    if spark_installed: -        return isinstance(df, spark_df) and not _is_spark_df(df) -    else: -        return False +    return spark_df is not None and isinstance(df, spark_df) and not _is_spark_df(df)   def _ensure_pandas(df) -> pd.DataFrame: diff --git a/lale/lib/category_encoders/target_encoder.py b/lale/lib/category_encoders/target_encoder.py index 67cacf49d..717a53b51 100644 --- a/lale/lib/category_encoders/target_encoder.py +++ b/lale/lib/category_encoders/target_encoder.py @@ -19,10 +19,12 @@  try: import category_encoders +    from category_encoders import TargetEncoder as Base  catenc_version = version.parse(getattr(category_encoders, "__version__")) except ImportError: +    Base = None catenc_version = None  import lale.docstrings @@ -169,11 +171,11 @@ class _TargetEncoderImpl:  def __init__(self, **hyperparams): -        if catenc_version is None: +        if catenc_version is None or Base is None: raise ValueError("The package 'category_encoders' is not installed.") with warnings.catch_warnings(): warnings.filterwarnings("ignore", category=FutureWarning) -            self._wrapped_model = category_encoders.TargetEncoder(**hyperparams) +            self._wrapped_model = Base(**hyperparams)  def fit(self, X, y): if catenc_version is None: diff --git a/lale/lib/lale/grid_search_cv.py b/lale/lib/lale/grid_search_cv.py index 62a88156e..cbbd6fec4 100644 --- a/lale/lib/lale/grid_search_cv.py +++ b/lale/lib/lale/grid_search_cv.py @@ -29,13 +29,11 @@  from .observing import Observing  -func_timeout_installed = False try: from func_timeout import FunctionTimedOut, func_timeout - -    func_timeout_installed = True except ImportError: -    pass +    FunctionTimedOut = None +    func_timeout = None   class _GridSearchCVImpl: @@ -152,7 +150,7 @@ def fit(self, X, y, **fit_params): n_jobs=self._hyperparams["n_jobs"], ) if self._hyperparams["max_opt_time"] is not None: -            if func_timeout_installed: +            if func_timeout is not None and FunctionTimedOut is not None: try: func_timeout( self._hyperparams["max_opt_time"], self.grid.fit, (X, y) diff --git a/lale/lib/lale/halving_grid_search_cv.py b/lale/lib/lale/halving_grid_search_cv.py index df28d3d44..0f16936db 100644 --- a/lale/lib/lale/halving_grid_search_cv.py +++ b/lale/lib/lale/halving_grid_search_cv.py @@ -34,9 +34,9 @@  try: from func_timeout import FunctionTimedOut, func_timeout  -    func_timeout_installed = True except ImportError: -    pass +    FunctionTimedOut = None +    func_timeout = None   class _HalvingGridSearchCVImpl: @@ -170,7 +170,7 @@ def fit(self, X, y, **fit_params): n_jobs=self._hyperparams["n_jobs"], ) if self._hyperparams["max_opt_time"] is not None: -            if func_timeout_installed: +            if func_timeout is not None and FunctionTimedOut is not None: try: func_timeout( self._hyperparams["max_opt_time"], self.grid.fit, (X, y) diff --git 
a/lale/lib/lale/smac.py b/lale/lib/lale/smac.py index c5ffa9d09..8546d0a13 100644 --- a/lale/lib/lale/smac.py +++ b/lale/lib/lale/smac.py @@ -40,21 +40,8 @@ with warnings.catch_warnings(): warnings.filterwarnings("ignore", category=FutureWarning) from smac.configspace import ConfigurationSpace - - # Import SMAC-utilities - from smac.facade.smac_facade import SMAC as orig_SMAC - from smac.scenario.scenario import Scenario - from smac.tae.execute_ta_run import BudgetExhaustedException - - from lale.search.lale_smac import ( # pylint:disable=wrong-import-position,ungrouped-imports - get_smac_space, - lale_op_smac_tae, - lale_trainable_op_from_config, - ) - - smac_installed = True except ImportError: - smac_installed = False + ConfigurationSpace = None logger = logging.getLogger(__name__) @@ -72,7 +59,9 @@ def __init__( max_opt_time=None, lale_num_grids=None, ): - assert smac_installed, """Your Python environment does not have smac installed. You can install it with + assert ( + ConfigurationSpace is not None + ), """Your Python environment does not have smac installed. You can install it with pip install smac<=0.10.0 or with pip install 'lale[full]'""" @@ -98,10 +87,22 @@ def __init__( self.trials = None def fit(self, X_train, y_train, **fit_params): + + # Import SMAC-utilities + from smac.facade.smac_facade import SMAC as orig_SMAC + from smac.scenario.scenario import Scenario + from smac.tae.execute_ta_run import BudgetExhaustedException + + from lale.search.lale_smac import ( # pylint:disable=wrong-import-position,ungrouped-imports + get_smac_space, + lale_op_smac_tae, + lale_trainable_op_from_config, + ) + data_schema = lale.helpers.fold_schema( X_train, y_train, self.cv, self.estimator.is_classifier() ) - self.search_space: ConfigurationSpace = get_smac_space( + self.search_space = get_smac_space( self.estimator, lale_num_grids=self.lale_num_grids, data_schema=data_schema ) # Scenario object diff --git a/lale/lib/rasl/_eval_spark_df.py b/lale/lib/rasl/_eval_spark_df.py index 45bd5c789..d8d550186 100644 --- a/lale/lib/rasl/_eval_spark_df.py +++ b/lale/lib/rasl/_eval_spark_df.py @@ -18,16 +18,13 @@ from lale.helpers import _ast_func_id try: - import pyspark.sql.functions - # noqa in the imports here because those get used dynamically and flake fails. 
- from pyspark.sql.functions import col  # noqa -    from pyspark.sql.functions import lit  # noqa -    from pyspark.sql.functions import to_timestamp  # noqa -    from pyspark.sql.functions import hour as spark_hour  # noqa -    from pyspark.sql.functions import isnan, isnull  # noqa -    from pyspark.sql.functions import minute as spark_minute  # noqa -    from pyspark.sql.functions import month as spark_month  # noqa +    from pyspark.sql.functions import col +    from pyspark.sql.functions import hour as spark_hour +    from pyspark.sql.functions import isnan, isnull, lit +    from pyspark.sql.functions import minute as spark_minute +    from pyspark.sql.functions import month as spark_month +    from pyspark.sql.functions import to_timestamp  from pyspark.sql.types import LongType  from pyspark.sql.functions import (  # noqa; isort: skip @@ -42,6 +39,26 @@  spark_installed = True  except ImportError: +    lit = None +    col = None +    to_timestamp = None +    isnan = None +    isnull = None +    spark_udf = None +    spark_floor = None +    LongType = None +    spark_when = None +    dayofmonth = None +    dayofweek = None +    dayofyear = None +    spark_md5 = None +    spark_hour = None +    spark_minute = None +    spark_month = None +     spark_installed = False @@ -59,21 +76,18 @@ class _SparkEvaluator(ast.NodeVisitor): def __init__(self): self.result = None -    def visit_Num(self, node: ast.Num): -        self.result = lit(node.n) - -    def visit_Str(self, node: ast.Str): -        self.result = lit(node.s) -     def visit_Constant(self, node: ast.Constant): +        assert lit is not None self.result = lit(node.value)      def visit_Attribute(self, node: ast.Attribute): column_name = _it_column(node) +        assert col is not None self.result = col(column_name)  # type: ignore      def visit_Subscript(self, node: ast.Subscript): column_name = _it_column(node) +        assert col is not None self.result = col(column_name)  # type: ignore      def visit_BinOp(self, node: ast.BinOp): @@ -92,6 +106,7 @@ def visit_BinOp(self, node: ast.BinOp): elif isinstance(node.op, ast.Div): self.result = v1 / v2 elif isinstance(node.op, ast.FloorDiv): +            assert spark_floor is not None self.result = spark_floor(v1 / v2) elif isinstance(node.op, ast.Mod): self.result = v1 % v2 @@ -155,7 +170,9 @@ def hash(call: ast.Call):  # pylint:disable=redefined-builtin hashing_method = ast.literal_eval(call.args[0]) column = _eval_ast_expr_spark_df(call.args[1])  # type: ignore if hashing_method == "md5": -        hash_fun = spark_md5(column)  # type: ignore +        assert spark_md5 is not None +        assert column is not None +        hash_fun = spark_md5(column) else: raise ValueError(f"Unimplementade hash function in Spark: {hashing_method}") return hash_fun @@ -164,6 +181,8 @@ def hash_mod(call: ast.Call): h_column = hash(call) N = ast.literal_eval(call.args[2]) +    assert spark_udf is not None +    assert LongType is not None int16_mod_N = spark_udf((lambda x: int(x, 16) % N), LongType()) return int16_mod_N(h_column) @@ -188,18 +207,23 @@ def replace(call: ast.Call): handle_unknown = ast.literal_eval(call.args[2]) chain_of_whens = None +    assert column is not None for key, value in mapping_dict.items(): if key == "nan": -            when_expr = isnan(column)  # type: ignore +            assert isnan is not None +            when_expr = isnan(column) elif key is None: -            when_expr = isnull(column)  # type: ignore +            assert isnull is not None +            when_expr = isnull(column) else: -            when_expr = column == key  # type: ignore +            when_expr = column == key if chain_of_whens is None: -            chain_of_whens = 
pyspark.sql.functions.when(when_expr, value) + assert spark_when is not None + chain_of_whens = spark_when(when_expr, value) else: chain_of_whens = chain_of_whens.when(when_expr, value) if handle_unknown == "use_encoded_value": + assert lit is not None fallback = lit(ast.literal_eval(call.args[3])) else: fallback = column @@ -216,31 +240,39 @@ def identity(call: ast.Call): def time_functions(call, spark_func): column = _eval_ast_expr_spark_df(call.args[0]) + assert to_timestamp is not None + assert column is not None if len(call.args) > 1: fmt = ast.literal_eval(call.args[1]) - return spark_func(to_timestamp(column, format=fmt)) # type: ignore - return spark_func(to_timestamp(column)) # type: ignore + return spark_func(to_timestamp(column, format=fmt)) + return spark_func(to_timestamp(column)) def day_of_month(call: ast.Call): + assert dayofmonth is not None return time_functions(call, dayofmonth) def day_of_week(call: ast.Call): + assert dayofweek is not None return time_functions(call, dayofweek) def day_of_year(call: ast.Call): + assert dayofyear is not None return time_functions(call, dayofyear) def hour(call: ast.Call): + assert spark_hour is not None return time_functions(call, spark_hour) def minute(call: ast.Call): + assert spark_minute is not None return time_functions(call, spark_minute) def month(call: ast.Call): + assert spark_month is not None return time_functions(call, spark_month) diff --git a/lale/lib/rasl/concat_features.py b/lale/lib/rasl/concat_features.py index e6d9aa580..93cd8cddd 100644 --- a/lale/lib/rasl/concat_features.py +++ b/lale/lib/rasl/concat_features.py @@ -36,11 +36,9 @@ try: - import torch - - torch_installed = True + from torch import Tensor except ImportError: - torch_installed = False + Tensor = None def _is_pandas_df(d): @@ -122,7 +120,7 @@ def join(d1, d2): np_dataset = dataset.toPandas().values elif isinstance(dataset, scipy.sparse.csr_matrix): np_dataset = dataset.toarray() - elif torch_installed and isinstance(dataset, torch.Tensor): + elif Tensor is not None and isinstance(dataset, Tensor): np_dataset = dataset.detach().cpu().numpy() else: np_dataset = dataset diff --git a/lale/lib/rasl/datasets.py b/lale/lib/rasl/datasets.py index 5042b3b1b..ea083d27f 100644 --- a/lale/lib/rasl/datasets.py +++ b/lale/lib/rasl/datasets.py @@ -39,6 +39,7 @@ else: _PandasOrSparkBatchAux = _PandasBatch # type: ignore + DataFrame = None # pyright does not currently accept a TypeAlias with conditional definitions _PandasOrSparkBatch: TypeAlias = _PandasOrSparkBatchAux # type: ignore @@ -48,16 +49,16 @@ from lale.datasets.openml import openml_datasets # pylint:disable=ungrouped-imports - liac_arff_installed = True except ModuleNotFoundError: - liac_arff_installed = False + arff = None + openml_datasets = None # type: ignore def arff_data_loader( file_name: str, label_name: str, rows_per_batch: int ) -> Iterable[_PandasBatch]: """Incrementally load an ARFF file and yield it one (X, y) batch at a time.""" - assert liac_arff_installed + assert arff is not None split_x_y = SplitXy(label_name=label_name) def make_batch(): @@ -144,7 +145,8 @@ def mockup_data_loader( def openml_data_loader(dataset_name: str, batch_size: int) -> Iterable[_PandasBatch]: """Download the OpenML dataset, incrementally load it, and yield it one (X,y) batch at a time.""" - assert liac_arff_installed + assert arff is not None + assert openml_datasets is not None metadata = openml_datasets.experiments_dict[dataset_name] label_name = cast(str, metadata["target"]).lower() file_name = 
openml_datasets.download_if_missing(dataset_name) diff --git a/lale/lib/rasl/filter.py b/lale/lib/rasl/filter.py index d2d748d34..fc62caa21 100644 --- a/lale/lib/rasl/filter.py +++ b/lale/lib/rasl/filter.py @@ -32,10 +32,9 @@ try: from pyspark.sql.functions import col - spark_installed = True - except ImportError: - spark_installed = False + + col = None class _FilterImpl: @@ -120,6 +119,7 @@ def filter_fun(X): # Filtering spark dataframes if _is_spark_df(X): + assert col is not None if isinstance(op, ast.Eq): assert lhs is not None assert rhs is not None diff --git a/lale/lib/rasl/functions.py b/lale/lib/rasl/functions.py index be1ce3ef5..fd039a840 100644 --- a/lale/lib/rasl/functions.py +++ b/lale/lib/rasl/functions.py @@ -30,11 +30,11 @@ from .monoid import Monoid, MonoidFactory try: - import pyspark.sql.functions - - spark_installed = True + from pyspark.sql.functions import isnan as spark_isnan + from pyspark.sql.functions import isnull as spark_isnull except ImportError: - spark_installed = False + spark_isnan = None + spark_isnull = None class _column_distinct_count_data(Monoid): @@ -233,8 +233,8 @@ def is_date_time(column_values): def filter_isnan(df: Any, column_name: str): if _is_pandas_df(df): return df[df[column_name].isnull()] - elif spark_installed and _is_spark_df(df): - return df.filter(pyspark.sql.functions.isnan(df[column_name])) + elif spark_isnan is not None and _is_spark_df(df): + return df.filter(spark_isnan(df[column_name])) else: raise ValueError( "the filter isnan supports only Pandas dataframes or spark dataframes." @@ -244,8 +244,8 @@ def filter_isnan(df: Any, column_name: str): def filter_isnotnan(df: Any, column_name: str): if _is_pandas_df(df): return df[df[column_name].notnull()] - elif spark_installed and _is_spark_df(df): - return df.filter(~pyspark.sql.functions.isnan(df[column_name])) + elif spark_isnan is not None and _is_spark_df(df): + return df.filter(~spark_isnan(df[column_name])) else: raise ValueError( "the filter isnotnan supports only Pandas dataframes or spark dataframes." @@ -255,8 +255,8 @@ def filter_isnotnan(df: Any, column_name: str): def filter_isnull(df: Any, column_name: str): if _is_pandas_df(df): return df[df[column_name].isnull()] - elif spark_installed and _is_spark_df(df): - return df.filter(pyspark.sql.functions.isnull(df[column_name])) + elif spark_isnull is not None and _is_spark_df(df): + return df.filter(spark_isnull(df[column_name])) else: raise ValueError( "the filter isnan supports only Pandas dataframes or spark dataframes." @@ -266,8 +266,8 @@ def filter_isnull(df: Any, column_name: str): def filter_isnotnull(df: Any, column_name: str): if _is_pandas_df(df): return df[df[column_name].notnull()] - elif spark_installed and _is_spark_df(df): - return df.filter(~pyspark.sql.functions.isnull(df[column_name])) + elif spark_isnull is not None and _is_spark_df(df): + return df.filter(~spark_isnull(df[column_name])) else: raise ValueError( "the filter isnotnan supports only Pandas dataframes or spark dataframes." 
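Note: most hunks up to this point apply the same refactoring. Instead of tracking an optional dependency with a separate `*_installed` boolean, the imported name itself serves as the sentinel and is bound to `None` when the import fails. A minimal sketch of the pattern, using a hypothetical `filter_nan_rows` helper for illustration:

```python
# Optional-import sentinel pattern, as used throughout this diff.
# The except branch binds the name to None, so a single `is not None`
# check both guards the call at runtime and lets pyright narrow the
# name from Optional[...] to a callable on the success branch.
try:
    from pyspark.sql.functions import isnan as spark_isnan
except ImportError:  # pyspark is an optional dependency
    spark_isnan = None


def filter_nan_rows(df, column_name: str):
    """Keep only rows where the column is NaN (hypothetical helper)."""
    if spark_isnan is not None:
        # On this branch pyright knows spark_isnan is the real function.
        return df.filter(spark_isnan(df[column_name]))
    raise ValueError("pyspark is required for this operation")
```

Compared with a module-level `spark_installed` flag, the sentinel keeps the availability check and the use of the name in the same scope, which is what enables the `assert col is not None` style narrowing seen throughout the diff.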
diff --git a/lale/lib/rasl/group_by.py b/lale/lib/rasl/group_by.py index c5e92553d..ca3224c87 100644 --- a/lale/lib/rasl/group_by.py +++ b/lale/lib/rasl/group_by.py @@ -57,7 +57,8 @@ def transform(self, X): if _is_pandas_df(X): grouped_df = X.groupby(group_by_keys, sort=False) elif _is_spark_df(X): - X = X.drop(*get_index_names(X)) + idx_names = get_index_names(X) + X = X.drop(*(idx_names if idx_names is not None else [])) grouped_df = X.groupby(group_by_keys) else: raise ValueError( diff --git a/lale/lib/rasl/join.py b/lale/lib/rasl/join.py index 6efae032c..44001714f 100644 --- a/lale/lib/rasl/join.py +++ b/lale/lib/rasl/join.py @@ -11,13 +11,17 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Iterable, List, Optional, Set +from typing import Any, Iterable, List, Literal, Optional, Set import pandas as pd import lale.docstrings import lale.operators -from lale.datasets.data_schemas import add_table_name, get_table_name +from lale.datasets.data_schemas import ( # pylint:disable=ungrouped-imports + SparkDataFrameWithIndex, + add_table_name, + get_table_name, +) from lale.helpers import ( _get_subscript_value, _is_ast_attribute, @@ -31,14 +35,10 @@ try: from pyspark.sql.functions import col - from lale.datasets.data_schemas import ( # pylint:disable=ungrouped-imports - SparkDataFrameWithIndex, - ) - - spark_installed = True - except ImportError: - spark_installed = False + col = None + +MergeHow = Literal["left", "right", "inner", "outer", "cross"] class _JoinImpl: @@ -48,7 +48,7 @@ def __init__( pred=None, join_limit=None, sliding_window_length=None, - join_type="inner", + join_type: MergeHow = "inner", name=None, ): self.pred = pred @@ -167,6 +167,8 @@ def join_df(left_df, right_df): left_table = left_df.alias("left_table") right_table = right_df.alias("right_table") + assert col is not None + for k, key in enumerate(left_key_col): on.append( col(f"{'left_table'}.{key}").eqNullSafe( diff --git a/lale/lib/rasl/map.py b/lale/lib/rasl/map.py index 9e1c41f5d..4e3469a39 100644 --- a/lale/lib/rasl/map.py +++ b/lale/lib/rasl/map.py @@ -33,12 +33,9 @@ from lale.lib.rasl._eval_spark_df import eval_expr_spark_df try: - # noqa in the imports here because those get used dynamically and flake fails. 
- from pyspark.sql.functions import col as spark_col # noqa - - spark_installed = True + from pyspark.sql.functions import col as spark_col except ImportError: - spark_installed = False + spark_col = None def _new_column_name(name, expr): @@ -179,6 +176,8 @@ def get_map_function_output(column, new_column_name): return mapped_df def transform_spark_df(self, X): + assert spark_col is not None + new_columns = [] accessed_column_names = set() diff --git a/lale/lib/rasl/metrics.py b/lale/lib/rasl/metrics.py index 493a01b2b..7744bf3bc 100644 --- a/lale/lib/rasl/metrics.py +++ b/lale/lib/rasl/metrics.py @@ -38,8 +38,9 @@ if spark_installed: from pyspark.sql.dataframe import DataFrame as SparkDataFrame + from pyspark.sql.dataframe import DataFrame as SparkDataFrameT - _SparkBatch: TypeAlias = Tuple[SparkDataFrame, SparkDataFrame] + _SparkBatch: TypeAlias = Tuple[SparkDataFrameT, SparkDataFrameT] _Batch_XyAux = Union[_PandasBatch, _SparkBatch] @@ -56,6 +57,8 @@ Union[pd.Series, np.ndarray], Union[pd.Series, np.ndarray], pd.DataFrame ] + SparkDataFrame = None + # pyright does not currently accept a TypeAlias with conditional definitions _Batch_Xy: TypeAlias = _Batch_XyAux # type: ignore _Batch_yyX: TypeAlias = _Batch_yyXAux # type: ignore @@ -117,7 +120,7 @@ def make_series_y(y): series = pd.Series(y) elif isinstance(y, pd.DataFrame): series = y.squeeze() - elif spark_installed and isinstance(y, SparkDataFrame): + elif SparkDataFrame is not None and isinstance(y, SparkDataFrame): series = cast(pd.DataFrame, y.toPandas()).squeeze() else: series = y diff --git a/lale/lib/rasl/monoid.py b/lale/lib/rasl/monoid.py index 0d931b818..1de29e839 100644 --- a/lale/lib/rasl/monoid.py +++ b/lale/lib/rasl/monoid.py @@ -38,7 +38,7 @@ def combine(self: _SelfType, other: _SelfType) -> _SelfType: pass @property - def is_absorbing(self): + def is_absorbing(self) -> bool: """ A monoid value `x` is absorbing if for all `y`, `x.combine(y) == x`. This can help stop training early for monoids with learned coefficients. diff --git a/lale/lib/rasl/scores.py b/lale/lib/rasl/scores.py index e86cd35bf..21897600b 100644 --- a/lale/lib/rasl/scores.py +++ b/lale/lib/rasl/scores.py @@ -225,5 +225,5 @@ def to_monoid(self, batch: Tuple[Any, Any]) -> FOnewayData: X, y = batch return _f_oneway_lift(X, y) - def from_monoid(self, monoid: FOnewayData): + def from_monoid(self, monoid: FOnewayData): # type: ignore return _f_oneway_lower(monoid) diff --git a/lale/lib/rasl/task_graphs.py b/lale/lib/rasl/task_graphs.py index 523b4db0e..e6e9bf052 100644 --- a/lale/lib/rasl/task_graphs.py +++ b/lale/lib/rasl/task_graphs.py @@ -1243,8 +1243,8 @@ def is_moot(task2): # same modulo batch_ids input_X = functools.reduce(lambda a, b: a.union(b), list_X) # type: ignore input_y = functools.reduce(lambda a, b: a.union(b), list_y) # type: ignore elif all(isinstance(X, np.ndarray) for X in list_X): - input_X = np.concatenate(list_X) - input_y = np.concatenate(list_y) + input_X = np.concatenate(list_X) # type: ignore + input_y = np.concatenate(list_y) # type: ignore else: raise ValueError( f"""Input of {type(list_X[0])} is not supported for diff --git a/lale/lib/snapml/batched_tree_ensemble_classifier.py b/lale/lib/snapml/batched_tree_ensemble_classifier.py index d9191e772..810ef8f34 100644 --- a/lale/lib/snapml/batched_tree_ensemble_classifier.py +++ b/lale/lib/snapml/batched_tree_ensemble_classifier.py @@ -11,19 +11,17 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and # limitations under the License. -try: - import snapml # type: ignore - - snapml_installed = True -except ImportError: - snapml_installed = False - import pandas as pd import lale.datasets.data_schemas import lale.docstrings import lale.operators +try: + from snapml import BatchedTreeEnsembleClassifier as Base +except ImportError: + Base = None + def _ensure_numpy(data): if isinstance(data, (pd.DataFrame, pd.Series)): @@ -34,13 +32,13 @@ def _ensure_numpy(data): class _BatchedTreeEnsembleClassifierImpl: def __init__(self, **hyperparams): assert ( - snapml_installed + Base is not None ), """Your Python environment does not have snapml installed. Install using: pip install snapml""" if hyperparams.get("base_ensemble", None) is None: from snapml import SnapBoostingMachineClassifier hyperparams["base_ensemble"] = SnapBoostingMachineClassifier() - self._wrapped_model = snapml.BatchedTreeEnsembleClassifier(**hyperparams) + self._wrapped_model = Base(**hyperparams) def fit(self, X, y, **fit_params): X = _ensure_numpy(X) diff --git a/lale/lib/snapml/batched_tree_ensemble_regressor.py b/lale/lib/snapml/batched_tree_ensemble_regressor.py index 1f7293419..5cdaa83b8 100644 --- a/lale/lib/snapml/batched_tree_ensemble_regressor.py +++ b/lale/lib/snapml/batched_tree_ensemble_regressor.py @@ -11,19 +11,17 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -try: - import snapml # type: ignore - - snapml_installed = True -except ImportError: - snapml_installed = False - import pandas as pd import lale.datasets.data_schemas import lale.docstrings import lale.operators +try: + from snapml import BatchedTreeEnsembleRegressor as Base +except ImportError: + Base = None + def _ensure_numpy(data): if isinstance(data, (pd.DataFrame, pd.Series)): @@ -34,13 +32,13 @@ def _ensure_numpy(data): class _BatchedTreeEnsembleRegressorImpl: def __init__(self, **hyperparams): assert ( - snapml_installed + Base is not None ), """Your Python environment does not have snapml installed. Install using: pip install snapml""" if hyperparams.get("base_ensemble") is None: from snapml import SnapBoostingMachineRegressor hyperparams["base_ensemble"] = SnapBoostingMachineRegressor() - self._wrapped_model = snapml.BatchedTreeEnsembleRegressor(**hyperparams) + self._wrapped_model = Base(**hyperparams) def fit(self, X, y, **fit_params): X = _ensure_numpy(X) diff --git a/lale/lib/snapml/snap_boosting_machine_classifier.py b/lale/lib/snapml/snap_boosting_machine_classifier.py index 3cffb7d35..44b37cedd 100644 --- a/lale/lib/snapml/snap_boosting_machine_classifier.py +++ b/lale/lib/snapml/snap_boosting_machine_classifier.py @@ -13,24 +13,25 @@ # limitations under the License. from packaging import version +import lale.datasets.data_schemas +import lale.docstrings +import lale.operators + try: - import snapml # type: ignore + import snapml + from snapml import SnapBoostingMachineClassifier as Base snapml_version = version.parse(getattr(snapml, "__version__")) except ImportError: + Base = None snapml_version = None -import lale.datasets.data_schemas -import lale.docstrings -import lale.operators - - class _SnapBoostingMachineClassifierImpl: def __init__(self, **hyperparams): assert ( - snapml_version is not None + snapml_version is not None and Base is not None ), """Your Python environment does not have snapml installed. 
Install using: pip install snapml""" if ( @@ -39,7 +40,7 @@ def __init__(self, **hyperparams): ): hyperparams["gpu_ids"] = [0] - self._wrapped_model = snapml.SnapBoostingMachineClassifier(**hyperparams) + self._wrapped_model = Base(**hyperparams) def fit(self, X, y, **fit_params): X = lale.datasets.data_schemas.strip_schema(X) diff --git a/lale/lib/snapml/snap_boosting_machine_regressor.py b/lale/lib/snapml/snap_boosting_machine_regressor.py index 5d2344c02..d7e93fca1 100644 --- a/lale/lib/snapml/snap_boosting_machine_regressor.py +++ b/lale/lib/snapml/snap_boosting_machine_regressor.py @@ -13,25 +13,27 @@ # limitations under the License. from packaging import version +import lale.datasets.data_schemas +import lale.docstrings +import lale.operators + try: - import snapml # type: ignore + import snapml + from snapml import SnapBoostingMachineRegressor as Base snapml_version = version.parse(getattr(snapml, "__version__")) except ImportError: + Base = None snapml_version = None -import lale.datasets.data_schemas -import lale.docstrings -import lale.operators - class _SnapBoostingMachineRegressorImpl: def __init__(self, **hyperparams): assert ( - snapml_version is not None + snapml_version is not None and Base is not None ), """Your Python environment does not have snapml installed. Install using: pip install snapml""" - self._wrapped_model = snapml.SnapBoostingMachineRegressor(**hyperparams) + self._wrapped_model = Base(**hyperparams) def fit(self, X, y, **fit_params): X = lale.datasets.data_schemas.strip_schema(X) diff --git a/lale/lib/snapml/snap_decision_tree_classifier.py b/lale/lib/snapml/snap_decision_tree_classifier.py index 3eb4d82f7..2dcf33ae0 100644 --- a/lale/lib/snapml/snap_decision_tree_classifier.py +++ b/lale/lib/snapml/snap_decision_tree_classifier.py @@ -11,25 +11,24 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -try: - import snapml # type: ignore - - snapml_installed = True -except ImportError: - snapml_installed = False import lale.datasets.data_schemas import lale.docstrings import lale.operators +try: + from snapml import SnapDecisionTreeClassifier as Base +except ImportError: + Base = None + class _SnapDecisionTreeClassifierImpl: def __init__(self, **hyperparams): assert ( - snapml_installed + Base is not None ), """Your Python environment does not have snapml installed. Install using: pip install snapml""" - self._wrapped_model = snapml.SnapDecisionTreeClassifier(**hyperparams) + self._wrapped_model = Base(**hyperparams) def fit(self, X, y, **fit_params): X = lale.datasets.data_schemas.strip_schema(X) diff --git a/lale/lib/snapml/snap_decision_tree_regressor.py b/lale/lib/snapml/snap_decision_tree_regressor.py index b970988df..685bbfcbe 100644 --- a/lale/lib/snapml/snap_decision_tree_regressor.py +++ b/lale/lib/snapml/snap_decision_tree_regressor.py @@ -11,25 +11,24 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-try: - import snapml # type: ignore - - snapml_installed = True -except ImportError: - snapml_installed = False import lale.datasets.data_schemas import lale.docstrings import lale.operators +try: + from snapml import SnapDecisionTreeRegressor as Base +except ImportError: + Base = None + class _SnapDecisionTreeRegressorImpl: def __init__(self, **hyperparams): assert ( - snapml_installed + Base is not None ), """Your Python environment does not have snapml installed. Install using: pip install snapml""" - self._wrapped_model = snapml.SnapDecisionTreeRegressor(**hyperparams) + self._wrapped_model = Base(**hyperparams) def fit(self, X, y, **fit_params): X = lale.datasets.data_schemas.strip_schema(X) diff --git a/lale/lib/snapml/snap_linear_regression.py b/lale/lib/snapml/snap_linear_regression.py index 82c3c313a..00abacc13 100644 --- a/lale/lib/snapml/snap_linear_regression.py +++ b/lale/lib/snapml/snap_linear_regression.py @@ -11,28 +11,27 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -try: - import snapml # type: ignore - - snapml_installed = True -except ImportError: - snapml_installed = False import lale.datasets.data_schemas import lale.docstrings import lale.operators +try: + from snapml import LinearRegression as Base +except ImportError: + Base = None + class _SnapLinearRegressionImpl: def __init__(self, **hyperparams): assert ( - snapml_installed + Base is not None ), """Your Python environment does not have snapml installed. Install using: pip install snapml""" if hyperparams.get("device_ids", None) is None: hyperparams["device_ids"] = [] - self._wrapped_model = snapml.LinearRegression(**hyperparams) + self._wrapped_model = Base(**hyperparams) def fit(self, X, y, **fit_params): X = lale.datasets.data_schemas.strip_schema(X) diff --git a/lale/lib/snapml/snap_logistic_regression.py b/lale/lib/snapml/snap_logistic_regression.py index 3eaa0f76f..76ee518ca 100644 --- a/lale/lib/snapml/snap_logistic_regression.py +++ b/lale/lib/snapml/snap_logistic_regression.py @@ -11,28 +11,27 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -try: - import snapml # type: ignore - - snapml_installed = True -except ImportError: - snapml_installed = False import lale.datasets.data_schemas import lale.docstrings import lale.operators +try: + from snapml import SnapLogisticRegression as Base +except ImportError: + Base = None + class _SnapLogisticRegressionImpl: def __init__(self, **hyperparams): assert ( - snapml_installed + Base is not None ), """Your Python environment does not have snapml installed. Install using: pip install snapml""" if hyperparams.get("device_ids", None) is None: hyperparams["device_ids"] = [] - self._wrapped_model = snapml.SnapLogisticRegression(**hyperparams) + self._wrapped_model = Base(**hyperparams) def fit(self, X, y, **fit_params): X = lale.datasets.data_schemas.strip_schema(X) diff --git a/lale/lib/snapml/snap_random_forest_classifier.py b/lale/lib/snapml/snap_random_forest_classifier.py index b0d895182..bae73b7ba 100644 --- a/lale/lib/snapml/snap_random_forest_classifier.py +++ b/lale/lib/snapml/snap_random_forest_classifier.py @@ -13,22 +13,24 @@ # limitations under the License. 
from packaging import version +import lale.datasets.data_schemas +import lale.docstrings +import lale.operators + try: - import snapml # type: ignore + import snapml + from snapml import SnapRandomForestClassifier as Base snapml_version = version.parse(getattr(snapml, "__version__")) except ImportError: + Base = None snapml_version = None -import lale.datasets.data_schemas -import lale.docstrings -import lale.operators - class _SnapRandomForestClassifierImpl: def __init__(self, **hyperparams): assert ( - snapml_version is not None + snapml_version is not None and Base is not None ), """Your Python environment does not have snapml installed. Install using: pip install snapml""" if ( @@ -39,7 +41,7 @@ def __init__(self, **hyperparams): if hyperparams.get("gpu_ids", None) is None: hyperparams["gpu_ids"] = [0] - self._wrapped_model = snapml.SnapRandomForestClassifier(**hyperparams) + self._wrapped_model = Base(**hyperparams) def fit(self, X, y, **fit_params): X = lale.datasets.data_schemas.strip_schema(X) diff --git a/lale/lib/snapml/snap_random_forest_regressor.py b/lale/lib/snapml/snap_random_forest_regressor.py index 2f79df4b4..8e9e1549d 100644 --- a/lale/lib/snapml/snap_random_forest_regressor.py +++ b/lale/lib/snapml/snap_random_forest_regressor.py @@ -13,22 +13,24 @@ # limitations under the License. from packaging import version +import lale.datasets.data_schemas +import lale.docstrings +import lale.operators + try: - import snapml # type: ignore + import snapml + from snapml import SnapRandomForestRegressor as Base snapml_version = version.parse(getattr(snapml, "__version__")) except ImportError: + Base = None snapml_version = None -import lale.datasets.data_schemas -import lale.docstrings -import lale.operators - class _SnapRandomForestRegressorImpl: def __init__(self, **hyperparams): assert ( - snapml_version is not None + snapml_version is not None and Base is not None ), """Your Python environment does not have snapml installed. Install using: pip install snapml""" if ( @@ -39,7 +41,7 @@ def __init__(self, **hyperparams): if hyperparams.get("gpu_ids", None) is None: hyperparams["gpu_ids"] = [0] - self._wrapped_model = snapml.SnapRandomForestRegressor(**hyperparams) + self._wrapped_model = Base(**hyperparams) def fit(self, X, y, **fit_params): X = lale.datasets.data_schemas.strip_schema(X) diff --git a/lale/lib/snapml/snap_svm_classifier.py b/lale/lib/snapml/snap_svm_classifier.py index fb6d155a1..aee32209d 100644 --- a/lale/lib/snapml/snap_svm_classifier.py +++ b/lale/lib/snapml/snap_svm_classifier.py @@ -13,23 +13,25 @@ # limitations under the License. from packaging import version +import lale.datasets.data_schemas +import lale.docstrings +import lale.operators + try: - import snapml # type: ignore + import snapml + from snapml import SnapSVMClassifier as Base snapml_version = version.parse(getattr(snapml, "__version__")) except ImportError: + Base = None snapml_version = None -import lale.datasets.data_schemas -import lale.docstrings -import lale.operators - class _SnapSVMClassifierImpl: def __init__(self, **hyperparams): assert ( - snapml_version is not None + snapml_version is not None and Base is not None ), """Your Python environment does not have snapml installed. 
Install using: pip install snapml""" if snapml_version <= version.Version("1.8.0") and "loss" in hyperparams: @@ -38,7 +40,7 @@ def __init__(self, **hyperparams): if hyperparams.get("device_ids", None) is None: hyperparams["device_ids"] = [0] - self._wrapped_model = snapml.SnapSVMClassifier(**hyperparams) + self._wrapped_model = Base(**hyperparams) def fit(self, X, y, **fit_params): X = lale.datasets.data_schemas.strip_schema(X) diff --git a/lale/operators.py b/lale/operators.py index e148b5b64..4b31384f0 100644 --- a/lale/operators.py +++ b/lale/operators.py @@ -1246,7 +1246,7 @@ def predict(self, X: Any, **predict_params) -> Any: pass @abstractmethod - def predict_proba(self, X: Any): + def predict_proba(self, X: Any) -> Any: """Probability estimates for all classes. Parameters @@ -1262,7 +1262,7 @@ def predict_proba(self, X: Any): pass @abstractmethod - def decision_function(self, X: Any): + def decision_function(self, X: Any) -> Any: """Confidence scores for all classes. Parameters @@ -1278,7 +1278,7 @@ def decision_function(self, X: Any): pass @abstractmethod - def score_samples(self, X: Any): + def score_samples(self, X: Any) -> Any: """Scores for each sample in X. The type of scores depends on the operator. Parameters @@ -1294,7 +1294,7 @@ def score_samples(self, X: Any): pass @abstractmethod - def score(self, X: Any, y: Any, **score_params): + def score(self, X: Any, y: Any, **score_params) -> Any: """Performance evaluation with a default metric. Parameters @@ -1314,7 +1314,7 @@ def score(self, X: Any, y: Any, **score_params): pass @abstractmethod - def predict_log_proba(self, X: Any): + def predict_log_proba(self, X: Any) -> Any: """Predicted class log-probabilities for X. Parameters @@ -2375,8 +2375,10 @@ def _validate_hyperparams(self, hp_explicit, hp_all, hp_schema, class_): validate_schema_directly(hp_all, hp_schema) except jsonschema.ValidationError as e_orig: e = e_orig if e_orig.parent is None else e_orig.parent - validate_is_schema(e.schema) - schema = lale.pretty_print.to_string(e.schema) + sch = e.schema + assert isinstance(sch, dict) + validate_is_schema(sch) + schema = lale.pretty_print.to_string(sch) defaults = self.get_defaults() extra_keys = [k for k in hp_explicit.keys() if k not in defaults] @@ -2451,7 +2453,12 @@ def _validate_hyperparams(self, hp_explicit, hp_all, hp_schema, class_): schema = self.get_defaults() elif e.schema_path[0] == "allOf" and int(e.schema_path[1]) != 0: assert e.schema_path[2] == "anyOf" - descr = e.schema["description"] + schema = e.schema + if isinstance(schema, dict): + descr = schema["description"] + else: + descr = "Boolean schema" + if descr.endswith("."): descr = descr[:-1] reason = f"constraint {descr[0].lower()}{descr[1:]}" @@ -3177,12 +3184,12 @@ def __new__(cls, *args, _lale_trained=False, _lale_impl=None, **kwargs): or _lale_trained or (_lale_impl is not None and not hasattr(_lale_impl, "fit")) ): - obj = super().__new__(TrainedIndividualOp) + obj = super().__new__(TrainedIndividualOp) # type: ignore return obj else: # unless _lale_trained=True, we actually want to return a Trainable - obj = super().__new__(TrainableIndividualOp) - # apparently python does not call __ini__ if the type returned is not the + obj = super().__new__(TrainableIndividualOp) # type: ignore + # apparently python does not call __init__ if the type returned is not the # expected type obj.__init__(*args, **kwargs) return obj @@ -4028,7 +4035,24 @@ def _find_source_nodes(self) -> List[OpType_co]: result = [s for s in self.steps_list() if is_source[s]] 
return result - def _validate_or_transform_schema(self, X: Any, y: Any = None, validate=True): + @overload + def _validate_or_transform_schema( + self, X, *, y: Optional[Any], validate: Literal[False] + ) -> JSON_TYPE: ... + + @overload + def _validate_or_transform_schema( + self, X, *, y: Optional[Any] = None, validate: Literal[False] + ) -> JSON_TYPE: ... + + @overload + def _validate_or_transform_schema( + self, X, y: Optional[Any] = None, validate: Literal[True] = True + ) -> None: ... + + def _validate_or_transform_schema( + self, X: Any, y: Any = None, validate=True + ) -> Optional[JSON_TYPE]: def combine_schemas(schemas): n_datasets = len(schemas) if n_datasets == 1: @@ -4059,7 +4083,9 @@ def combine_schemas(schemas): output_X = operator.transform_schema(input_X) output_y = input_y outputs[operator] = output_X, output_y - if not validate: + if validate: + return None + else: sinks = self._find_sink_nodes() pipeline_outputs = [outputs[sink][0] for sink in sinks] return combine_schemas(pipeline_outputs) @@ -4067,7 +4093,7 @@ def combine_schemas(schemas): def validate_schema(self, X: Any, y: Any = None): self._validate_or_transform_schema(X, y, validate=True) - def transform_schema(self, s_X: JSON_TYPE): + def transform_schema(self, s_X: JSON_TYPE) -> JSON_TYPE: from lale.settings import disable_data_schema_validation if disable_data_schema_validation: @@ -4652,11 +4678,11 @@ def freeze_trained(self) -> "TrainedPipeline": class TrainedPipeline(TrainablePipeline[TrainedOpType_co], TrainedOperator): def __new__(cls, *args, _lale_trained=False, **kwargs): if "steps" not in kwargs or _lale_trained: - obj = super().__new__(TrainedPipeline) + obj = super().__new__(TrainedPipeline) # type: ignore return obj else: # unless _lale_trained=True, we actually want to return a Trainable - obj = super().__new__(TrainablePipeline) + obj = super().__new__(TrainablePipeline) # type: ignore # apparently python does not call __ini__ if the type returned is not the # expected type obj.__init__(*args, **kwargs) diff --git a/lale/search/schema2search_space.py b/lale/search/schema2search_space.py index 02c7a1089..4ef50ffd6 100644 --- a/lale/search/schema2search_space.py +++ b/lale/search/schema2search_space.py @@ -442,7 +442,7 @@ def schemaToSearchSpaceHelper_( if sub_space: return SearchSpaceDict(o) else: - all_keys = list(o.keys()) + all_keys: list[str] = list(o.keys()) all_keys.sort() o_choice = tuple(o.get(k, None) for k in all_keys) return SearchSpaceObject(longName, all_keys, [o_choice]) @@ -488,7 +488,7 @@ def schemaToSearchSpaceHelper_( ) if "anyOf" in schema: - objs = [] + objs: list[dict[str, SearchSpace]] = [] for s_obj in schema["anyOf"]: if "type" in s_obj and s_obj["type"] == "object": o = self.JsonSchemaToSearchSpaceHelper( diff --git a/mypy.ini b/mypy.ini index 0e4df7832..c3deb7cd1 100644 --- a/mypy.ini +++ b/mypy.ini @@ -1,5 +1,5 @@ [mypy] -python_version = 3.8 +python_version = 3.9 ignore_missing_imports = True [mypy-pandas] diff --git a/pyrightconfig.json b/pyrightconfig.json index fbe81973d..c096824c7 100644 --- a/pyrightconfig.json +++ b/pyrightconfig.json @@ -1,3 +1,4 @@ { - "exclude": ["build"] + "exclude": ["build"], + "useLibraryCodeForTypes": false } diff --git a/test/test_aif360.py b/test/test_aif360.py index 2e20d1d70..c47dbc422 100644 --- a/test/test_aif360.py +++ b/test/test_aif360.py @@ -40,10 +40,8 @@ try: import tensorflow as tf - - tensorflow_installed = True except ImportError: - tensorflow_installed = False + tf = None import lale.helpers import lale.lib.aif360 @@ -570,7 
+568,7 @@ def _attempt_remi_creditg_pd_num( disparate_impact_scorer = lale.lib.aif360.disparate_impact(**fairness_info) di_list = [] for split in splits: - if tensorflow_installed: # for AdversarialDebiasing + if tf is not None: # for AdversarialDebiasing tf.compat.v1.reset_default_graph() tf.compat.v1.disable_eager_execution() train_X = split["train_X"] @@ -603,7 +601,7 @@ def test_disparate_impact_remover_np_num(self): self.assertTrue(0.8 < impact_remi < 1.0, f"impact_remi {impact_remi}") def test_adversarial_debiasing_pd_num(self): - if tensorflow_installed: + if tf is not None: fairness_info = self.creditg_pd_num["fairness_info"] tf.compat.v1.reset_default_graph() trainable_remi = AdversarialDebiasing(**fairness_info) @@ -1462,7 +1460,7 @@ def _attempt_remi_creditg_pd_cat( disparate_impact_scorer = lale.lib.aif360.disparate_impact(**fairness_info) di_list = [] for split in splits: - if tensorflow_installed: # for AdversarialDebiasing + if tf is not None: # for AdversarialDebiasing tf.compat.v1.reset_default_graph() train_X = split["train_X"] train_y = split["train_y"] @@ -1479,7 +1477,7 @@ def _attempt_remi_creditg_pd_cat( ) def test_adversarial_debiasing_pd_cat(self): - if tensorflow_installed: + if tf is not None: fairness_info = self.creditg_pd_cat["fairness_info"] trainable_remi = AdversarialDebiasing( **fairness_info, preparation=self.prep_pd_cat diff --git a/test/test_aif360_ensembles.py b/test/test_aif360_ensembles.py index f1867f61d..4f509e74c 100644 --- a/test/test_aif360_ensembles.py +++ b/test/test_aif360_ensembles.py @@ -18,9 +18,8 @@ try: import tensorflow as tf - tensorflow_installed = True except ImportError: - tensorflow_installed = False + tf = None from lale.helpers import with_fixed_estimator_name from lale.lib.aif360 import ( @@ -94,7 +93,7 @@ def test_bagging_in_estimator_mitigation_base(self): self._attempt_fit_predict(model) def test_bagging_in_estimator_mitigation_base_1(self): - if tensorflow_installed: + if tf is not None: tf.compat.v1.disable_eager_execution() model = BaggingClassifier( **with_fixed_estimator_name( diff --git a/test/test_core_transformers.py b/test/test_core_transformers.py index 2db7812da..70ffd3908 100644 --- a/test/test_core_transformers.py +++ b/test/test_core_transformers.py @@ -25,7 +25,7 @@ import lale.type_checking from lale.datasets import pandas2spark from lale.datasets.data_schemas import add_table_name, get_table_name -from lale.datasets.util import spark_installed +from lale.helpers import spark_installed from lale.lib.lale import ConcatFeatures from lale.lib.sklearn import ( NMF, diff --git a/test/test_relational.py b/test/test_relational.py index a70ac798d..73756300d 100644 --- a/test/test_relational.py +++ b/test/test_relational.py @@ -13,6 +13,7 @@ # limitations under the License. 
import unittest +from test import EnableSchemaValidation # pylint:disable=wrong-import-order from typing import List import jsonschema @@ -22,25 +23,9 @@ from sklearn.model_selection import train_test_split import lale.operators -from lale.lib.rasl.convert import Convert -from lale.operator_wrapper import wrap_imported_operators - -try: - from pyspark import SparkConf, SparkContext - from pyspark.sql import Row, SparkSession, SQLContext - - from lale.datasets.data_schemas import ( # pylint:disable=ungrouped-imports - SparkDataFrameWithIndex, - ) - - spark_installed = True -except ImportError: - spark_installed = False - -from test import EnableSchemaValidation # pylint:disable=wrong-import-order - from lale.datasets import pandas2spark from lale.datasets.data_schemas import ( + SparkDataFrameWithIndex, add_table_name, get_index_name, get_table_name, @@ -98,7 +83,18 @@ Scan, SortIndex, ) +from lale.lib.rasl.convert import Convert from lale.lib.sklearn import PCA, KNeighborsClassifier, LogisticRegression +from lale.operator_wrapper import wrap_imported_operators + +try: + from pyspark import SparkConf, SparkContext + from pyspark.sql import SparkSession, SQLContext +except ImportError: + SparkConf = None + SparkContext = None + SparkSession = None + SQLContext = None def _set_index_name(df, name): @@ -171,7 +167,12 @@ def setUpClass(cls): (5, "CA", 0, float(3)), ] - if spark_installed: + if SparkConf is not None: + from pyspark.sql import Row + + assert SparkContext is not None + assert SQLContext is not None + conf = ( SparkConf() .setMaster("local[2]") @@ -2388,9 +2389,10 @@ def test_str1(self): class TestSplitXy(unittest.TestCase): @classmethod - def setUp(cls): # pylint:disable=arguments-differ + def setUpClass(cls): data = load_iris() - X, y = data.data, data.target + X, y = data.data, data.target # type: ignore + X_train, _X_test, y_train, _y_test = train_test_split( pd.DataFrame(X), pd.DataFrame(y) )
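Note on the `lale/operators.py` hunk above: the `@overload` declarations for `_validate_or_transform_schema` use `typing.Literal` so that pyright infers `JSON_TYPE` when `validate=False` is passed and `None` when it is omitted or `True`. A minimal, self-contained sketch of the technique, with hypothetical names (`check_or_transform`, `JSON`):

```python
# @overload + Literal flag: the static return type depends on the
# value passed for `validate`, mirroring _validate_or_transform_schema.
from typing import Any, Literal, Optional, overload

JSON = dict  # stand-in for lale's JSON_TYPE alias


@overload
def check_or_transform(x: Any, validate: Literal[True] = True) -> None: ...
@overload
def check_or_transform(x: Any, validate: Literal[False]) -> JSON: ...


def check_or_transform(x: Any, validate: bool = True) -> Optional[JSON]:
    if validate:
        # Validation mode runs checks for their side effects only.
        return None
    # Transformation mode returns a schema-like dictionary.
    return {"type": type(x).__name__}


schema = check_or_transform(42, validate=False)  # pyright infers dict
check_or_transform(42)  # pyright infers None, so callers need no cast
```

This is also why the diff adds an explicit `return None` on the validating branch: the implementation now has to match the `Optional[JSON_TYPE]` signature that the overloads jointly describe.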