@@ -48,15 +48,16 @@ def __init__(
4848 super ().__init__ (** kwargs )
4949 self .save_hyperparameters (ignore = ["cat_feature_info" , "num_feature_info" ])
5050
51- self .lr = self .hparams .get ("lr" , config .lr )
52- self .lr_patience = self .hparams .get ("lr_patience" , config .lr_patience )
53- self .weight_decay = self .hparams .get ("weight_decay" , config .weight_decay )
54- self .lr_factor = self .hparams .get ("lr_factor" , config .lr_factor )
55- self .cat_feature_info = cat_feature_info
56- self .num_feature_info = num_feature_info
57- self .num_classes = num_classes
58- self .interaction_degree = self .hparams .get (
59- "interaction_degree" , config .interaction_degree
51+ self .lr = self .hparams .get ("lr" , config .lr )
52+ self .lr_patience = self .hparams .get ("lr_patience" , config .lr_patience )
53+ self .weight_decay = self .hparams .get ("weight_decay" , config .weight_decay )
54+ self .lr_factor = self .hparams .get ("lr_factor" , config .lr_factor )
55+ self .l2_lambda = self .hparams .get ("l2_lambda" , config .l2_lambda )
56+ self .cat_feature_info = cat_feature_info
57+ self .num_feature_info = num_feature_info
58+ self .num_classes = num_classes
59+ self .interaction_degree = self .hparams .get (
60+ "interaction_degree" , config .interaction_degree
6061 )
6162
6263 # Calculate total input dimension
@@ -128,11 +129,19 @@ def forward(self, num_features: dict, cat_features: dict) -> dict:
128129 # Apply feature dropout
129130 x = self .feature_dropout (x )
130131
131- # Get prediction from the model
132- output = self .model (x )
133-
134- # Create result dictionary
135- result = {"output" : output }
132+ # Get prediction (and optional regularization penalty) from the model
133+ penalty = None
134+ if self .l2_lambda and self .l2_lambda > 0 :
135+ output = self .model (x , return_outputs_penalty = True )
136+ if isinstance (output , tuple ):
137+ output , penalty = output
138+ else :
139+ output = self .model (x )
140+
141+ # Create result dictionary
142+ result = {"output" : output }
143+ if penalty is not None :
144+ result ["penalty" ] = penalty
136145
137146 # Add individual feature outputs for interpretability
138147 for i , feature_name in enumerate (feature_names ):
0 commit comments