diff --git a/m2cgen/assemblers/__init__.py b/m2cgen/assemblers/__init__.py index a817d331..791ccb85 100644 --- a/m2cgen/assemblers/__init__.py +++ b/m2cgen/assemblers/__init__.py @@ -40,6 +40,7 @@ # LightGBM "lightgbm_LGBMClassifier": LightGBMModelAssembler, "lightgbm_LGBMRegressor": LightGBMModelAssembler, + "lightgbm_Booster": LightGBMModelAssembler, # XGBoost "xgboost_XGBClassifier": XGBoostModelAssemblerSelector, diff --git a/m2cgen/assemblers/boosting.py b/m2cgen/assemblers/boosting.py index 66bf8baf..5ec6c105 100644 --- a/m2cgen/assemblers/boosting.py +++ b/m2cgen/assemblers/boosting.py @@ -219,7 +219,14 @@ class LightGBMModelAssembler(BaseTreeBoostingAssembler): classifier_names = {"LGBMClassifier"} def __init__(self, model): - model_dump = model.booster_.dump_model() + if hasattr(model, "booster_"): + # Scikit-learn interface (e.g. lightgbm.LGBMClassifier, lightgbm.LGBMRegressor) + # https://lightgbm.readthedocs.io/en/stable/Python-API.html#scikit-learn-api + model_dump = model.booster_.dump_model() + else: + # Python-API interface (e.g. 
lightgbm.train) + # https://lightgbm.readthedocs.io/en/stable/pythonapi/lightgbm.train.html + model_dump = model.dump_model() trees = [m["tree_structure"] for m in model_dump["tree_info"]] self.n_iter = len(trees) // model_dump["num_tree_per_iteration"] diff --git a/tests/assemblers/test_boosting_lightgbm.py b/tests/assemblers/test_boosting_lightgbm.py index 743c83c5..a2e0c7e8 100644 --- a/tests/assemblers/test_boosting_lightgbm.py +++ b/tests/assemblers/test_boosting_lightgbm.py @@ -41,6 +41,39 @@ def test_binary_classification(): assert utils.cmp_exprs(actual, expected) +def test_binary_classification_booster(): + estimator = lgb.LGBMClassifier(n_estimators=2, random_state=1, max_depth=1) + utils.get_binary_classification_model_trainer()(estimator) + + assembler = LightGBMModelAssembler(estimator.booster_) + actual = assembler.assemble() + + sigmoid = ast.SigmoidExpr( + ast.BinNumExpr( + ast.IfExpr( + ast.CompExpr( + ast.FeatureRef(20), + ast.NumVal(16.795), + ast.CompOpType.GT), + ast.NumVal(0.27502096830384837), + ast.NumVal(0.6391171126839048)), + ast.IfExpr( + ast.CompExpr( + ast.FeatureRef(27), + ast.NumVal(0.14205), + ast.CompOpType.GT), + ast.NumVal(-0.21340153096570616), + ast.NumVal(0.11583109256834748)), + ast.BinNumOpType.ADD), + to_reuse=True) + + expected = ast.VectorVal([ + ast.BinNumExpr(ast.NumVal(1), sigmoid, ast.BinNumOpType.SUB), + sigmoid]) + + assert utils.cmp_exprs(actual, expected) + + def test_multi_class(): estimator = lgb.LGBMClassifier(n_estimators=1, random_state=1, max_depth=1) estimator.fit(np.array([[1], [2], [3]]), np.array([1, 2, 3]))