Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion lightautoml/ml_algo/dl_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -393,6 +393,11 @@ def _infer_params(self, train=None, valid=None):

return model

def _get_tabm_k(self) -> int:
    """Get TabM ensemble size from model parameters.

    The value is looked up in ``self.params["backbone_params"]["k"]`` first,
    then in ``self.params["k"]``, and finally defaults to 32. A missing or
    ``None`` entry at one level falls through to the next one, so an
    explicit ``k: None`` never leaks out as the ensemble size.

    Returns:
        Ensemble size ``k``.

    """
    # ``backbone_params`` may be absent or explicitly ``None``.
    backbone_params = self.params.get("backbone_params") or {}
    for candidate in (backbone_params.get("k"), self.params.get("k")):
        # Compare against ``None`` (not truthiness) so an explicit 0 is kept.
        if candidate is not None:
            return candidate
    return 32

@staticmethod
def get_mean_target(target, task_name: str):
"""Get target mean / inverse sigmoid transformation \
Expand Down Expand Up @@ -558,7 +563,7 @@ def get_dataloaders_from_dicts(self, data_dict: Dict):
dataset_size=len(dataset),
batch_size=self.train_params["bs"],
shuffle=is_shuffle(stage),
k=self.params.get("k", 32),
k=self._get_tabm_k(),
share_training_batches=self.params.get("share_training_batches", True),
device=self.params.get("device", torch.device("cuda:0")),
),
Expand Down
11 changes: 6 additions & 5 deletions lightautoml/text/nn_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,11 @@ def loss_fn_wrapper(
self.text_embedder = text_embedder(**text_params)
n_in += self.text_embedder.get_out_shape()

backbone_params = dict(kwargs.get("backbone_params") or {})
backbone_params["start_scaling_init_chunks"] = (
start_scaling_init_chunks if len(start_scaling_init_chunks) > 0 else None
)

self.torch_model = (
torch_model(
**{
Expand All @@ -201,11 +206,7 @@ def loss_fn_wrapper(
"n_out": n_out,
"loss": loss,
"task": task,
"backbone_params": {
"start_scaling_init_chunks": start_scaling_init_chunks
if len(start_scaling_init_chunks) > 0
else None
},
"backbone_params": backbone_params,
},
}
)
Expand Down
46 changes: 46 additions & 0 deletions tests/unit/test_automl/test_presets/test_tabularautoml_tabm.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,58 @@

from sklearn.metrics import roc_auc_score
import pytest
import torch.nn as nn

from lightautoml.ml_algo.dl_model import TorchModel
from lightautoml.automl.presets.tabular_presets import TabularAutoML
from lightautoml.tasks import Task
from lightautoml.text.embed import ContEmbedder
from lightautoml.text.nn_model import TorchUniversalModel
from tests.unit.test_automl.test_presets.presets_utils import check_pickling
from tests.unit.test_automl.test_presets.presets_utils import get_target_name


class _BackboneParamsRecorder(nn.Module):
    """Minimal stand-in network that remembers the ``backbone_params`` it received.

    Used to inspect which parameters TorchUniversalModel passes to its
    ``torch_model``. Any extra keyword arguments are accepted and ignored.
    """

    def __init__(self, n_in, n_out, backbone_params=None, **kwargs):
        super().__init__()
        # Single linear layer so the module stays a usable network.
        self.linear = nn.Linear(n_in, n_out)
        # Kept untouched so the test can inspect exactly what was passed in.
        self.backbone_params = backbone_params

    def forward(self, x):
        """Apply the linear layer to ``x``."""
        projected = self.linear(x)
        return projected


def test_torch_universal_model_preserves_backbone_params():
    """Check that user backbone params survive next to the TabM chunks added by the wrapper."""
    binary_task = Task("binary")
    torch_loss = binary_task.losses["torch"].loss

    universal_model = TorchUniversalModel(
        task=binary_task,
        loss=torch_loss,
        torch_model=_BackboneParamsRecorder,
        n_out=1,
        cont_embedder_=ContEmbedder,
        cont_params={"num_dims": 3, "input_bn": False, "embedding_size": 1},
        backbone_params={"k": 4, "arch_type": "tabm-mini"},
    )

    recorded = universal_model.torch_model.backbone_params
    # The first two entries are supplied by the caller above; the last one is
    # expected to be added on top of them by TorchUniversalModel.
    expected_entries = (
        ("k", 4),
        ("arch_type", "tabm-mini"),
        ("start_scaling_init_chunks", [1, 1, 1]),
    )
    for key, expected_value in expected_entries:
        assert recorded[key] == expected_value


def test_torch_model_uses_tabm_k_from_backbone_params():
    """Check that ``_get_tabm_k`` reads ``k`` from ``backbone_params`` or from the top-level params."""
    # Each case pairs the ``default_params`` given to TorchModel with the
    # ensemble size expected back; order matches the original checks.
    cases = (
        ({"backbone_params": {"k": 7}}, 7),
        ({"k": 5}, 5),
    )
    for default_params, expected_k in cases:
        model = TorchModel(default_params=default_params)
        assert model._get_tabm_k() == expected_k


class TestTabM:
"""Neural network test based on out-of-fold and test scores."""

Expand Down