Skip to content

Commit 7dbf427

Browse files
authored
Merge pull request #50 from SoftwareAG/xgboost-exporter-performance
Performance improvement for xgboost exporter
2 parents (cd29b9d + e463bf1), commit 7dbf427

1 file changed

Lines changed: 4 additions & 12 deletions

File tree

nyoka/xgboost/xgboost_to_pmml.py

Lines changed: 4 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -259,10 +259,7 @@ def get_segments_for_xgbr(model, derived_col_names, feature_names, target_name,
259259
Nyoka's Segment object
260260
261261
"""
262-
segments = list()
263-
get_nodes_in_json_format = []
264-
for i in range(model.n_estimators):
265-
get_nodes_in_json_format.append(json.loads(model._Booster.get_dump(dump_format='json')[i]))
262+
get_nodes_in_json_format = model._Booster.get_dump(dump_format='json')
266263
segmentation = pml.Segmentation(multipleModelMethod=MULTIPLE_MODEL_METHOD.SUM,
267264
Segment=generate_Segments_Equal_To_Estimators(get_nodes_in_json_format, derived_col_names,
268265
feature_names))
@@ -373,7 +370,7 @@ def generate_Segments_Equal_To_Estimators(val, derived_col_names, col_names):
373370
main_node = pml.Node(True_=pml.True_())
374371
m_flds = []
375372
mining_field_for_innner_segments = col_names
376-
create_node(val[i], main_node, derived_col_names)
373+
create_node(json.loads(val[i]), main_node, derived_col_names)
377374

378375
for name in mining_field_for_innner_segments:
379376
m_flds.append(pml.MiningField(name=name))
@@ -455,9 +452,7 @@ def get_segments_for_xgbc(model, derived_col_names, feature_names, target_name,
455452
segments = list()
456453

457454
if model.n_classes_ == 2:
458-
get_nodes_in_json_format=[]
459-
for i in range(model.n_estimators):
460-
get_nodes_in_json_format.append(json.loads(model._Booster.get_dump(dump_format='json')[i]))
455+
get_nodes_in_json_format=model._Booster.get_dump(dump_format='json')
461456
mining_schema_for_1st_segment = mining_Field_For_First_Segment(feature_names)
462457
outputField = list()
463458
outputField.append(pml.OutputField(name="xgbValue", optype=OPTYPE.CONTINUOUS, dataType=DATATYPE.FLOAT,
@@ -476,10 +471,7 @@ def get_segments_for_xgbc(model, derived_col_names, feature_names, target_name,
476471

477472
segments.append(last_segment)
478473
else:
479-
480-
get_nodes_in_json_format = []
481-
for i in range(model.n_estimators * model.n_classes_):
482-
get_nodes_in_json_format.append(json.loads(model._Booster.get_dump(dump_format='json')[i]))
474+
get_nodes_in_json_format = model._Booster.get_dump(dump_format='json')
483475
oField = list()
484476
for index in range(0, model.n_classes_):
485477
inner_segment = []

0 commit comments

Comments
 (0)