
Commit 42b4a7c

Author: ananyapam
Update nodegam.py: syntax fixes and code improvements
Parent: e8d019c

1 file changed: nampy/basemodels/nodegam.py
Lines changed: 23 additions & 14 deletions
```diff
@@ -48,15 +48,16 @@ def __init__(
         super().__init__(**kwargs)
         self.save_hyperparameters(ignore=["cat_feature_info", "num_feature_info"])
 
-        self.lr = self.hparams.get("lr", config.lr)
-        self.lr_patience = self.hparams.get("lr_patience", config.lr_patience)
-        self.weight_decay = self.hparams.get("weight_decay", config.weight_decay)
-        self.lr_factor = self.hparams.get("lr_factor", config.lr_factor)
-        self.cat_feature_info = cat_feature_info
-        self.num_feature_info = num_feature_info
-        self.num_classes = num_classes
-        self.interaction_degree = self.hparams.get(
-            "interaction_degree", config.interaction_degree
+        self.lr = self.hparams.get("lr", config.lr)
+        self.lr_patience = self.hparams.get("lr_patience", config.lr_patience)
+        self.weight_decay = self.hparams.get("weight_decay", config.weight_decay)
+        self.lr_factor = self.hparams.get("lr_factor", config.lr_factor)
+        self.l2_lambda = self.hparams.get("l2_lambda", config.l2_lambda)
+        self.cat_feature_info = cat_feature_info
+        self.num_feature_info = num_feature_info
+        self.num_classes = num_classes
+        self.interaction_degree = self.hparams.get(
+            "interaction_degree", config.interaction_degree
         )
 
         # Calculate total input dimension
```
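
The new `self.l2_lambda` line follows the same override pattern as the surrounding hyperparameters: a value passed to the constructor (and captured by `save_hyperparameters`) takes precedence, and the config default is used otherwise. A minimal standalone sketch of that lookup pattern, with a hypothetical `Config` dataclass standing in for the real nampy config object:

```python
from dataclasses import dataclass


@dataclass
class Config:
    # Hypothetical defaults standing in for the real NODE-GAM config
    lr: float = 1e-3
    l2_lambda: float = 0.0


def resolve_hparams(config: Config, **kwargs) -> dict:
    # Each value comes from kwargs if supplied, else from the config default,
    # mirroring self.hparams.get("l2_lambda", config.l2_lambda) in the diff
    return {
        "lr": kwargs.get("lr", config.lr),
        "l2_lambda": kwargs.get("l2_lambda", config.l2_lambda),
    }


print(resolve_hparams(Config()))                  # {'lr': 0.001, 'l2_lambda': 0.0}
print(resolve_hparams(Config(), l2_lambda=1e-4))  # {'lr': 0.001, 'l2_lambda': 0.0001}
```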
```diff
@@ -128,11 +129,19 @@ def forward(self, num_features: dict, cat_features: dict) -> dict:
         # Apply feature dropout
         x = self.feature_dropout(x)
 
-        # Get prediction from the model
-        output = self.model(x)
-
-        # Create result dictionary
-        result = {"output": output}
+        # Get prediction (and optional regularization penalty) from the model
+        penalty = None
+        if self.l2_lambda and self.l2_lambda > 0:
+            output = self.model(x, return_outputs_penalty=True)
+            if isinstance(output, tuple):
+                output, penalty = output
+        else:
+            output = self.model(x)
+
+        # Create result dictionary
+        result = {"output": output}
+        if penalty is not None:
+            result["penalty"] = penalty
 
         # Add individual feature outputs for interpretability
         for i, feature_name in enumerate(feature_names):
```
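
With this change, callers of `forward` may receive a `"penalty"` entry alongside `"output"`. A hedged sketch of how a training step could fold that penalty into the loss; the Lightning-style loop, the cross-entropy choice, and the assumption that the returned penalty is unscaled are all illustrative, not confirmed by this diff:

```python
import torch
import torch.nn.functional as F


def training_step_sketch(model, batch):
    """Hypothetical consumer of the new result dict; the real nampy
    training_step may differ."""
    num_features, cat_features, targets = batch
    result = model(num_features, cat_features)

    # Task loss on the raw model output
    loss = F.cross_entropy(result["output"], targets)

    # Fold in the optional regularization penalty, scaled by l2_lambda,
    # matching the diff's guard that only requests it when l2_lambda > 0
    if "penalty" in result:
        loss = loss + model.l2_lambda * result["penalty"]
    return loss
```

Note the `isinstance(output, tuple)` check in the diff: it keeps `forward` compatible with backbone models that accept `return_outputs_penalty` but still return a bare tensor, in which case `penalty` stays `None` and the result dict is unchanged.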
