2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -67,4 +67,4 @@ repos:
language: node
pass_filenames: false
types: [python]
additional_dependencies: ['pyright@1.1.291']
additional_dependencies: ['pyright@1.1.398']
4 changes: 2 additions & 2 deletions lale/datasets/data_schemas.py
@@ -315,8 +315,8 @@ def get_index_name(obj):
return result


def get_index_names(obj):
result = None
def get_index_names(obj) -> Optional[list[str]]:
result: Optional[list[str]] = None
if SparkDataFrame is not None and isinstance(obj, SparkDataFrameWithIndex):
result = obj.index_names
elif isinstance(
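As an aside on the typing change above: annotating both the return value and the local makes the Optional contract explicit to callers and to pyright. A minimal sketch of the same shape, using a made-up wrapper class instead of Lale's SparkDataFrameWithIndex:

from typing import Optional


class FrameWithIndex:  # placeholder for an optionally available wrapper class
    def __init__(self, index_names: list[str]):
        self.index_names = index_names


def get_index_names(obj) -> Optional[list[str]]:
    # Start from None and only fill in the result when the wrapper type matches.
    result: Optional[list[str]] = None
    if isinstance(obj, FrameWithIndex):
        result = obj.index_names
    return result


print(get_index_names(FrameWithIndex(["idx"])))  # ['idx']
print(get_index_names("not a frame"))            # None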
23 changes: 12 additions & 11 deletions lale/datasets/util.py
@@ -19,26 +19,27 @@
import pandas as pd
from sklearn.utils import Bunch

from lale.datasets.data_schemas import add_table_name, get_table_name

download_data_cache_dir: pathlib.Path = pathlib.Path(
os.environ.get("LALE_DOWNLOAD_CACHE_DIR", os.path.dirname(__file__))
from lale.datasets.data_schemas import (
SparkDataFrameWithIndex,
add_table_name,
get_table_name,
)

try:
from pyspark.sql import SparkSession

from lale.datasets.data_schemas import ( # pylint:disable=ungrouped-imports
SparkDataFrameWithIndex,
)

spark_installed = True
except ImportError:
spark_installed = False
SparkSession = None


download_data_cache_dir: pathlib.Path = pathlib.Path(
os.environ.get("LALE_DOWNLOAD_CACHE_DIR", os.path.dirname(__file__))
)


def pandas2spark(pandas_df):
assert spark_installed
assert SparkSession is not None

spark_session = (
SparkSession.builder.master("local[2]") # type: ignore
.config("spark.driver.memory", "64g")
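The import reshuffle above follows the usual optional-dependency pattern: try the pyspark import at module level, bind a None sentinel on failure so the module still imports, and assert at the point of use. A hedged sketch of that pattern; the session options are placeholders rather than Lale's exact configuration:

import pandas as pd

try:
    from pyspark.sql import SparkSession

    spark_installed = True
except ImportError:
    # pyspark is optional; the sentinel keeps this module importable without it
    spark_installed = False
    SparkSession = None


def pandas2spark(pandas_df: pd.DataFrame):
    # Fail fast, and give the type checker a narrowing hint, when pyspark is absent.
    assert spark_installed
    assert SparkSession is not None
    spark = SparkSession.builder.master("local[2]").getOrCreate()
    return spark.createDataFrame(pandas_df)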
6 changes: 6 additions & 0 deletions lale/expressions.py
@@ -12,6 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.

# Due to changes in the ast package, the types are confused, but the code still works.
# Changing the types presents compatibility challenges, so some pyright errors are disabled for now.
# pyright: reportArgumentType=false
# pyright: reportIncompatibleMethodOverride=false


import ast # see also https://greentreesnakes.readthedocs.io/
import pprint
import typing
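The "# pyright:" comments added above are file-scoped rule toggles: each one turns a single diagnostic off for the whole module. A self-contained illustration of the mechanism, unrelated to lale/expressions.py:

# pyright: reportArgumentType=false
# The directive above switches off argument-type diagnostics for this entire file.


def double(x: int) -> int:
    return 2 * x


# Normally reported under reportArgumentType; silenced file-wide by the directive.
# It still runs, because 2 * "21" simply evaluates to the string "2121".
value = double("21")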
7 changes: 3 additions & 4 deletions lale/helpers.py
@@ -55,6 +55,8 @@
spark_installed = spark_loader is not None
if spark_installed:
from pyspark.sql.dataframe import DataFrame as spark_df
else:
spark_df = None

logger = logging.getLogger(__name__)

@@ -1252,10 +1254,7 @@ def _is_spark_df(df):


def _is_spark_df_without_index(df):
if spark_installed:
return isinstance(df, spark_df) and not _is_spark_df(df)
else:
return False
return spark_df is not None and isinstance(df, spark_df) and not _is_spark_df(df)


def _ensure_pandas(df) -> pd.DataFrame:
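Binding spark_df = None in the new else branch is what lets _is_spark_df_without_index shrink to a single expression: the is-not-None check has to come first, because isinstance() raises a TypeError when its second argument is None. A small sketch of the same guard, with a stand-in helper name:

try:
    from pyspark.sql.dataframe import DataFrame as spark_df
except ImportError:
    spark_df = None  # sentinel: pyspark is not installed


def looks_like_spark_frame(df) -> bool:
    # Short-circuit on the sentinel before isinstance sees an invalid class argument.
    return spark_df is not None and isinstance(df, spark_df)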
6 changes: 4 additions & 2 deletions lale/lib/category_encoders/target_encoder.py
@@ -19,10 +19,12 @@

try:
import category_encoders
from category_encoders import TargetEncoder as Base

catenc_version = version.parse(getattr(category_encoders, "__version__"))

except ImportError:
Base = None
catenc_version = None

import lale.docstrings
@@ -169,11 +171,11 @@

class _TargetEncoderImpl:
def __init__(self, **hyperparams):
if catenc_version is None:
if catenc_version is None or Base is None:
raise ValueError("The package 'category_encoders' is not installed.")
with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=FutureWarning)
self._wrapped_model = category_encoders.TargetEncoder(**hyperparams)
self._wrapped_model = Base(**hyperparams)

def fit(self, X, y):
if catenc_version is None:
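Aliasing the class itself (Base = None on ImportError) lets the constructor test one name and raise a clear error when the optional package is missing, instead of referencing a possibly unbound module. A condensed sketch of that guard, assuming version comes from the packaging package as the surrounding file suggests:

import warnings

from packaging import version

try:
    import category_encoders
    from category_encoders import TargetEncoder as Base

    catenc_version = version.parse(category_encoders.__version__)
except ImportError:
    Base = None
    catenc_version = None


class _TargetEncoderImpl:
    def __init__(self, **hyperparams):
        if catenc_version is None or Base is None:
            raise ValueError("The package 'category_encoders' is not installed.")
        with warnings.catch_warnings():
            warnings.filterwarnings("ignore", category=FutureWarning)
            self._wrapped_model = Base(**hyperparams)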
8 changes: 3 additions & 5 deletions lale/lib/lale/grid_search_cv.py
@@ -29,13 +29,11 @@

from .observing import Observing

func_timeout_installed = False
try:
from func_timeout import FunctionTimedOut, func_timeout

func_timeout_installed = True
except ImportError:
pass
FunctionTimedOut = None
func_timeout = None


class _GridSearchCVImpl:
@@ -152,7 +150,7 @@ def fit(self, X, y, **fit_params):
n_jobs=self._hyperparams["n_jobs"],
)
if self._hyperparams["max_opt_time"] is not None:
if func_timeout_installed:
if func_timeout is not None and FunctionTimedOut is not None:
try:
func_timeout(
self._hyperparams["max_opt_time"], self.grid.fit, (X, y)
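The hunk above replaces the func_timeout_installed flag with None sentinels on the imported names, and the fit-time check now tests those names directly. A hedged sketch of how func_timeout bounds a long-running call; slow_fit here is a made-up stand-in for self.grid.fit:

import time

try:
    from func_timeout import FunctionTimedOut, func_timeout
except ImportError:
    FunctionTimedOut = None
    func_timeout = None


def slow_fit(X, y):  # stand-in for an expensive estimator fit
    time.sleep(10)


def fit_with_budget(X, y, max_opt_time=2.0):
    if func_timeout is not None and FunctionTimedOut is not None:
        try:
            # func_timeout(timeout_seconds, callable, args) raises FunctionTimedOut on expiry
            func_timeout(max_opt_time, slow_fit, (X, y))
        except FunctionTimedOut:
            print(f"fit did not finish within {max_opt_time} seconds")
    else:
        slow_fit(X, y)  # no timeout enforcement without func_timeout installed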
6 changes: 3 additions & 3 deletions lale/lib/lale/halving_grid_search_cv.py
@@ -34,9 +34,9 @@
try:
from func_timeout import FunctionTimedOut, func_timeout

func_timeout_installed = True
except ImportError:
pass
FunctionTimedOut = None
func_timeout = None


class _HalvingGridSearchCVImpl:
@@ -170,7 +170,7 @@ def fit(self, X, y, **fit_params):
n_jobs=self._hyperparams["n_jobs"],
)
if self._hyperparams["max_opt_time"] is not None:
if func_timeout_installed:
if func_timeout is not None and FunctionTimedOut is not None:
try:
func_timeout(
self._hyperparams["max_opt_time"], self.grid.fit, (X, y)
33 changes: 17 additions & 16 deletions lale/lib/lale/smac.py
@@ -40,21 +40,8 @@
with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=FutureWarning)
from smac.configspace import ConfigurationSpace

# Import SMAC-utilities
from smac.facade.smac_facade import SMAC as orig_SMAC
from smac.scenario.scenario import Scenario
from smac.tae.execute_ta_run import BudgetExhaustedException

from lale.search.lale_smac import ( # pylint:disable=wrong-import-position,ungrouped-imports
get_smac_space,
lale_op_smac_tae,
lale_trainable_op_from_config,
)

smac_installed = True
except ImportError:
smac_installed = False
ConfigurationSpace = None

logger = logging.getLogger(__name__)

@@ -72,7 +59,9 @@ def __init__(
max_opt_time=None,
lale_num_grids=None,
):
assert smac_installed, """Your Python environment does not have smac installed. You can install it with
assert (
ConfigurationSpace is not None
), """Your Python environment does not have smac installed. You can install it with
pip install smac<=0.10.0
or with
pip install 'lale[full]'"""
@@ -98,10 +87,22 @@ def __init__(
self.trials = None

def fit(self, X_train, y_train, **fit_params):

# Import SMAC-utilities
from smac.facade.smac_facade import SMAC as orig_SMAC
from smac.scenario.scenario import Scenario
from smac.tae.execute_ta_run import BudgetExhaustedException

from lale.search.lale_smac import ( # pylint:disable=wrong-import-position,ungrouped-imports
get_smac_space,
lale_op_smac_tae,
lale_trainable_op_from_config,
)

data_schema = lale.helpers.fold_schema(
X_train, y_train, self.cv, self.estimator.is_classifier()
)
self.search_space: ConfigurationSpace = get_smac_space(
self.search_space = get_smac_space(
self.estimator, lale_num_grids=self.lale_num_grids, data_schema=data_schema
)
# Scenario object
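The SMAC hunks move the facade, scenario, and search-space imports out of the module-level try block and into fit(), so importing this module only needs ConfigurationSpace to resolve; the rest is loaded on first use. A generic sketch of that deferred-import pattern; the package and function names below are placeholders, not the real smac API:

try:
    # Probe the optional dependency once at import time; this is all the assert needs.
    from some_optional_pkg import CoreConfig  # placeholder package and class
except ImportError:
    CoreConfig = None


class Optimizer:
    def __init__(self):
        assert (
            CoreConfig is not None
        ), "Your Python environment does not have some_optional_pkg installed."

    def fit(self, X, y):
        # Defer the heavier submodule imports until they are actually needed.
        from some_optional_pkg.facade import run_search  # placeholder

        return run_search(CoreConfig(), X, y)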
76 changes: 54 additions & 22 deletions lale/lib/rasl/_eval_spark_df.py
@@ -18,16 +18,13 @@
from lale.helpers import _ast_func_id

try:
import pyspark.sql.functions

# noqa in the imports here because those get used dynamically and flake fails.
from pyspark.sql.functions import col # noqa
from pyspark.sql.functions import lit # noqa
from pyspark.sql.functions import to_timestamp # noqa
from pyspark.sql.functions import hour as spark_hour # noqa
from pyspark.sql.functions import isnan, isnull # noqa
from pyspark.sql.functions import minute as spark_minute # noqa
from pyspark.sql.functions import month as spark_month # noqa
from pyspark.sql.functions import col
from pyspark.sql.functions import hour as spark_hour
from pyspark.sql.functions import isnan, isnull, lit
from pyspark.sql.functions import minute as spark_minute
from pyspark.sql.functions import month as spark_month
from pyspark.sql.functions import to_timestamp
from pyspark.sql.types import LongType

from pyspark.sql.functions import ( # noqa; isort: skip
@@ -42,6 +39,26 @@

spark_installed = True
except ImportError:
lit = None
col = None
to_timestamp = None
isnan = None
isnull = None
spark_udf = None
spark_floor = None
LongType = None
spark_when = None
dayofmonth = None
dayofweek = None
dayofyear = None
spark_floor = None
spark_md5 = None
spark_udf = None
spark_when = None
spark_hour = None
spark_minute = None
spark_month = None

spark_installed = False


@@ -59,21 +76,18 @@ class _SparkEvaluator(ast.NodeVisitor):
def __init__(self):
self.result = None

def visit_Num(self, node: ast.Num):
self.result = lit(node.n)

def visit_Str(self, node: ast.Str):
self.result = lit(node.s)

def visit_Constant(self, node: ast.Constant):
assert lit is not None
self.result = lit(node.value)

def visit_Attribute(self, node: ast.Attribute):
column_name = _it_column(node)
assert col is not None
self.result = col(column_name) # type: ignore

def visit_Subscript(self, node: ast.Subscript):
column_name = _it_column(node)
assert col is not None
self.result = col(column_name) # type: ignore

def visit_BinOp(self, node: ast.BinOp):
@@ -92,6 +106,7 @@ def visit_BinOp(self, node: ast.BinOp):
elif isinstance(node.op, ast.Div):
self.result = v1 / v2
elif isinstance(node.op, ast.FloorDiv):
assert spark_floor is not None
self.result = spark_floor(v1 / v2)
elif isinstance(node.op, ast.Mod):
self.result = v1 % v2
@@ -155,7 +170,9 @@ def hash(call: ast.Call): # pylint:disable=redefined-builtin
hashing_method = ast.literal_eval(call.args[0])
column = _eval_ast_expr_spark_df(call.args[1]) # type: ignore
if hashing_method == "md5":
hash_fun = spark_md5(column) # type: ignore
assert spark_md5 is not None
assert column is not None
hash_fun = spark_md5(column)
else:
raise ValueError(f"Unimplemented hash function in Spark: {hashing_method}")
return hash_fun
@@ -164,6 +181,8 @@ def hash_mod(call: ast.Call):
def hash_mod(call: ast.Call):
h_column = hash(call)
N = ast.literal_eval(call.args[2])
assert spark_udf is not None
assert LongType is not None
int16_mod_N = spark_udf((lambda x: int(x, 16) % N), LongType())
return int16_mod_N(h_column)

@@ -188,18 +207,23 @@ def replace(call: ast.Call):

handle_unknown = ast.literal_eval(call.args[2])
chain_of_whens = None
assert column is not None
for key, value in mapping_dict.items():
if key == "nan":
when_expr = isnan(column) # type: ignore
assert isnan is not None
when_expr = isnan(column)
elif key is None:
when_expr = isnull(column) # type: ignore
assert isnull is not None
when_expr = isnull(column)
else:
when_expr = column == key # type: ignore
when_expr = column == key
if chain_of_whens is None:
chain_of_whens = pyspark.sql.functions.when(when_expr, value)
assert spark_when is not None
chain_of_whens = spark_when(when_expr, value)
else:
chain_of_whens = chain_of_whens.when(when_expr, value)
if handle_unknown == "use_encoded_value":
assert lit is not None
fallback = lit(ast.literal_eval(call.args[3]))
else:
fallback = column
@@ -216,31 +240,39 @@ def identity(call: ast.Call):

def time_functions(call, spark_func):
column = _eval_ast_expr_spark_df(call.args[0])
assert to_timestamp is not None
assert column is not None
if len(call.args) > 1:
fmt = ast.literal_eval(call.args[1])
return spark_func(to_timestamp(column, format=fmt)) # type: ignore
return spark_func(to_timestamp(column)) # type: ignore
return spark_func(to_timestamp(column, format=fmt))
return spark_func(to_timestamp(column))


def day_of_month(call: ast.Call):
assert dayofmonth is not None
return time_functions(call, dayofmonth)


def day_of_week(call: ast.Call):
assert dayofweek is not None
return time_functions(call, dayofweek)


def day_of_year(call: ast.Call):
assert dayofyear is not None
return time_functions(call, dayofyear)


def hour(call: ast.Call):
assert spark_hour is not None
return time_functions(call, spark_hour)


def minute(call: ast.Call):
assert spark_minute is not None
return time_functions(call, spark_minute)


def month(call: ast.Call):
assert spark_month is not None
return time_functions(call, spark_month)
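Dropping visit_Num and visit_Str in favor of a single visit_Constant tracks the ast module itself: ast.Num and ast.Str have been deprecated since Python 3.8, and literals now parse as ast.Constant nodes. A small NodeVisitor sketch showing the consolidated method, unrelated to the Spark evaluator above:

import ast


class LiteralCollector(ast.NodeVisitor):
    """Collect every literal value appearing in an expression tree."""

    def __init__(self):
        self.values = []

    def visit_Constant(self, node: ast.Constant):
        # Numbers, strings, booleans, and None all arrive here as ast.Constant.
        self.values.append(node.value)
        self.generic_visit(node)


collector = LiteralCollector()
collector.visit(ast.parse("x + 3 * 'ab' if flag else None", mode="eval"))
print(collector.values)  # [3, 'ab', None]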