diff --git a/m2cgen/assemblers/boosting.py b/m2cgen/assemblers/boosting.py index 66bf8baf..b55725b0 100644 --- a/m2cgen/assemblers/boosting.py +++ b/m2cgen/assemblers/boosting.py @@ -17,7 +17,10 @@ class BaseBoostingAssembler(ModelAssembler): def __init__(self, model, estimator_params, base_score=0.0): super().__init__(model) self._all_estimator_params = estimator_params - self._base_score = base_score + if base_score is None: + self._base_score = 0.0 + else: + self._base_score = base_score self._output_size = 1 self._is_classification = False @@ -141,10 +144,15 @@ def __init__(self, model): # Limit the number of trees that should be used for # assembling (if applicable). best_ntree_limit = getattr(model, "best_ntree_limit", None) - + if model.get_params().get("base_score") is not None: + base_score = model.get_params()["base_score"] + elif model.intercept_ is not None: + base_score = model.intercept_[0] + else: + base_score = 0.0 super().__init__(model, trees, - base_score=model.get_params()["base_score"], + base_score=base_score, tree_limit=best_ntree_limit) def _assemble_tree(self, tree):