From eeb1fe12c94d484aa68a9fe878ed4860988b5a2b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=D0=A1=D1=82=D0=B5=D1=84=D0=B0=D0=BD=20=D0=91=D0=B8=D0=B4?= =?UTF-8?q?=D0=B6=D0=B0=D0=BC=D0=BE=D0=B2?= Date: Fri, 12 Jun 2026 00:55:01 +0300 Subject: [PATCH] Fix TabM backbone parameter propagation --- lightautoml/ml_algo/dl_model.py | 7 ++- lightautoml/text/nn_model.py | 11 +++-- .../test_presets/test_tabularautoml_tabm.py | 46 +++++++++++++++++++ 3 files changed, 58 insertions(+), 6 deletions(-) diff --git a/lightautoml/ml_algo/dl_model.py b/lightautoml/ml_algo/dl_model.py index b6e18fc4..6fc1cc43 100644 --- a/lightautoml/ml_algo/dl_model.py +++ b/lightautoml/ml_algo/dl_model.py @@ -393,6 +393,11 @@ def _infer_params(self, train=None, valid=None): return model + def _get_tabm_k(self) -> int: + """Get TabM ensemble size from model parameters.""" + backbone_params = self.params.get("backbone_params") or {} + return backbone_params.get("k", self.params.get("k", 32)) + @staticmethod def get_mean_target(target, task_name: str): """Get target mean / inverse sigmoid transformation \ @@ -558,7 +563,7 @@ def get_dataloaders_from_dicts(self, data_dict: Dict): dataset_size=len(dataset), batch_size=self.train_params["bs"], shuffle=is_shuffle(stage), - k=self.params.get("k", 32), + k=self._get_tabm_k(), share_training_batches=self.params.get("share_training_batches", True), device=self.params.get("device", torch.device("cuda:0")), ), diff --git a/lightautoml/text/nn_model.py b/lightautoml/text/nn_model.py index 42265bd5..29e0d8c7 100644 --- a/lightautoml/text/nn_model.py +++ b/lightautoml/text/nn_model.py @@ -192,6 +192,11 @@ def loss_fn_wrapper( self.text_embedder = text_embedder(**text_params) n_in += self.text_embedder.get_out_shape() + backbone_params = dict(kwargs.get("backbone_params") or {}) + backbone_params["start_scaling_init_chunks"] = ( + start_scaling_init_chunks if len(start_scaling_init_chunks) > 0 else None + ) + self.torch_model = ( torch_model( **{ @@ -201,11 +206,7 @@ def loss_fn_wrapper( "n_out": n_out, "loss": loss, "task": task, - "backbone_params": { - "start_scaling_init_chunks": start_scaling_init_chunks - if len(start_scaling_init_chunks) > 0 - else None - }, + "backbone_params": backbone_params, }, } ) diff --git a/tests/unit/test_automl/test_presets/test_tabularautoml_tabm.py b/tests/unit/test_automl/test_presets/test_tabularautoml_tabm.py index 39d7977e..a6d2c52b 100644 --- a/tests/unit/test_automl/test_presets/test_tabularautoml_tabm.py +++ b/tests/unit/test_automl/test_presets/test_tabularautoml_tabm.py @@ -3,12 +3,58 @@ from sklearn.metrics import roc_auc_score import pytest +import torch.nn as nn +from lightautoml.ml_algo.dl_model import TorchModel from lightautoml.automl.presets.tabular_presets import TabularAutoML +from lightautoml.tasks import Task +from lightautoml.text.embed import ContEmbedder +from lightautoml.text.nn_model import TorchUniversalModel from tests.unit.test_automl.test_presets.presets_utils import check_pickling from tests.unit.test_automl.test_presets.presets_utils import get_target_name +class _BackboneParamsRecorder(nn.Module): + """Tiny model used to inspect parameters passed by TorchUniversalModel.""" + + def __init__(self, n_in, n_out, backbone_params=None, **kwargs): + super().__init__() + self.backbone_params = backbone_params + self.linear = nn.Linear(n_in, n_out) + + def forward(self, x): + return self.linear(x) + + +def test_torch_universal_model_preserves_backbone_params(): + """Test that wrapper adds TabM chunks without dropping user backbone params.""" + task = Task("binary") + + model = TorchUniversalModel( + task=task, + loss=task.losses["torch"].loss, + torch_model=_BackboneParamsRecorder, + n_out=1, + cont_embedder_=ContEmbedder, + cont_params={"num_dims": 3, "input_bn": False, "embedding_size": 1}, + backbone_params={"k": 4, "arch_type": "tabm-mini"}, + ) + + backbone_params = model.torch_model.backbone_params + assert backbone_params["k"] == 4 + assert backbone_params["arch_type"] == "tabm-mini" + assert backbone_params["start_scaling_init_chunks"] == [1, 1, 1] + + +def test_torch_model_uses_tabm_k_from_backbone_params(): + """Test that TabM sampler follows the model ensemble size.""" + model = TorchModel(default_params={"backbone_params": {"k": 7}}) + assert model._get_tabm_k() == 7 + + model = TorchModel(default_params={"k": 5}) + assert model._get_tabm_k() == 5 + + class TestTabM: """Neural network test based on out-of-fold and test scores."""