Commit 4230db1

ENH add ElasticNet (#230)
Co-authored-by: mathurinm <[email protected]>
1 parent a4263f3 commit 4230db1

13 files changed

Lines changed: 657 additions & 106 deletions

README.rst

Lines changed: 2 additions & 0 deletions
@@ -9,6 +9,8 @@ Currently, the package handles the following problems:
 
 - Lasso
 - Weighted Lasso
+- ElasticNet
+- Weighted ElasticNet
 - Sparse Logistic regression
 - Weighted Group Lasso
 - Multitask Lasso
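
In terms of the optimization problem, the two new entries correspond to the (weighted) ElasticNet objective. A sketch of that objective, consistent with the primal_lasso change in celer/cython_utils.pyx below (the 1/(2 n_samples) data-fit scaling is assumed to follow the existing Lasso convention):

    \min_w \; \frac{1}{2 n_{\mathrm{samples}}} \lVert y - Xw \rVert_2^2
        + \alpha \sum_j \mathrm{weights}_j \Big( \mathrm{l1\_ratio}\, \lvert w_j \rvert
        + \frac{1 - \mathrm{l1\_ratio}}{2}\, w_j^2 \Big)

With l1_ratio = 1 this reduces to the weighted Lasso already listed above.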

celer/PN_logreg.pyx

Lines changed: 25 additions & 17 deletions
@@ -13,8 +13,8 @@ from libc.math cimport fabs, sqrt, exp
 from sklearn.exceptions import ConvergenceWarning
 
 from .cython_utils cimport fdot, faxpy, fcopy, fposv, fscal, fnrm2
-from .cython_utils cimport (primal, dual, create_dual_pt, create_accel_pt,
-                            sigmoid, ST, LOGREG, dnorm_l1,
+from .cython_utils cimport (primal, dual, create_dual_pt,
+                            sigmoid, ST, LOGREG, dnorm_enet,
                             compute_Xw, compute_norms_X_col, set_prios)
 
 cdef:
@@ -35,6 +35,10 @@ def newton_celer(
     else:
         dtype = np.float32
 
+    # Enet not supported for Logreg
+    cdef floating l1_ratio = 1.0
+    cdef floating norm_w2 = 0.
+
     cdef int verbose_in = max(0, verbose - 1)
     cdef int n_samples = y.shape[0]
     cdef int n_features = w.shape[0]
@@ -102,19 +106,19 @@ def newton_celer(
     cdef bint positive = 0
 
     for t in range(max_iter):
-        p_obj = primal(LOGREG, alpha, Xw, y, w, weights_pen)
+        p_obj = primal(LOGREG, alpha, l1_ratio, Xw, y, w, weights_pen)
 
         # theta = y * sigmoid(-y * Xw)
         create_dual_pt(LOGREG, n_samples, &theta[0], &Xw[0], &y[0])
-        norm_Xtheta = dnorm_l1(
-            is_sparse, theta, X, X_data, X_indices, X_indptr,
-            screened, X_mean, weights_pen, center, positive)
+        norm_Xtheta = dnorm_enet(
+            is_sparse, theta, w, X, X_data, X_indices, X_indptr,
+            screened, X_mean, weights_pen, center, positive, alpha, l1_ratio)
 
         if norm_Xtheta > alpha:
             tmp = alpha / norm_Xtheta
             fscal(&n_samples, &tmp, &theta[0], &inc)
 
-        d_obj = dual(LOGREG, n_samples, 0., &theta[0], &y[0])
+        d_obj = dual(LOGREG, n_samples, alpha, l1_ratio, 0., norm_w2, &theta[0], &y[0])
         gap = p_obj - d_obj
 
         if t != 0 and use_accel:
@@ -165,15 +169,15 @@ def newton_celer(
             for i in range(n_samples):
                 exp_Xw[i] = exp(Xw[i])
 
-            norm_Xtheta_acc = dnorm_l1(
-                is_sparse, theta_acc, X, X_data, X_indices, X_indptr,
-                screened, X_mean, weights_pen, center, positive)
+            norm_Xtheta_acc = dnorm_enet(
+                is_sparse, theta_acc, w, X, X_data, X_indices, X_indptr,
+                screened, X_mean, weights_pen, center, positive, alpha, l1_ratio)
 
             if norm_Xtheta_acc > alpha:
                 tmp = alpha / norm_Xtheta_acc
                 fscal(&n_samples, &tmp, &theta_acc[0], &inc)
 
-            d_obj_acc = dual(LOGREG, n_samples, 0., &theta_acc[0], &y[0])
+            d_obj_acc = dual(LOGREG, n_samples, alpha, l1_ratio, 0., norm_w2, &theta_acc[0], &y[0])
             if d_obj_acc > d_obj:
                 fcopy(&n_samples, &theta_acc[0], &inc, &theta[0], &inc)
                 gap = p_obj - d_obj_acc
@@ -188,7 +192,7 @@ def newton_celer(
                 break
 
 
-        set_prios(is_sparse, theta, alpha, X, X_data, X_indices, X_indptr,
+        set_prios(is_sparse, theta, w, alpha, l1_ratio, X, X_data, X_indices, X_indptr,
                   norms_X_col, weights_pen, prios, screened, radius,
                   &n_screened, 0)
 
@@ -249,6 +253,10 @@ cpdef int PN_logreg(
     cdef int ws_size = WS.shape[0]
     cdef int n_features = w.shape[0]
 
+    # Enet not supported for Logreg
+    cdef floating l1_ratio = 1.0
+    cdef floating norm_w2 = 0.
+
     if floating is double:
         dtype = np.float64
     else:
@@ -369,15 +377,15 @@ cpdef int PN_logreg(
 
             else:
                 # rescale aux to create dual point
-                norm_Xaux = dnorm_l1(
-                    is_sparse, aux, X, X_data, X_indices, X_indptr,
-                    notin_WS, X_mean, weights_pen, center, 0)
+                norm_Xaux = dnorm_enet(
+                    is_sparse, aux, w, X, X_data, X_indices, X_indptr,
+                    notin_WS, X_mean, weights_pen, center, 0, alpha, l1_ratio)
 
                 for i in range(n_samples):
                     aux[i] /= max(1, norm_Xaux / alpha)
 
-                d_obj = dual(LOGREG, n_samples, 0, &aux[0], &y[0])
-                p_obj = primal(LOGREG, alpha, Xw, y, w, weights_pen)
+                d_obj = dual(LOGREG, n_samples, alpha, l1_ratio, 0., norm_w2, &aux[0], &y[0])
+                p_obj = primal(LOGREG, alpha, l1_ratio, Xw, y, w, weights_pen)
 
             gap = p_obj - d_obj
             if verbose_in:
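
Both solvers in this file keep the pure-l1 behaviour: ElasticNet is not supported for logistic regression, so they pass l1_ratio = 1.0 and norm_w2 = 0. into the updated primal/dual signatures and the new terms vanish. A minimal plain-Python sketch of that reduction (illustrative values only, not part of the commit):

    # with l1_ratio = 1.0 the ElasticNet penalty and the extra dual term disappear
    alpha, l1_ratio, norm_w2 = 0.5, 1.0, 0.0
    w_j = 0.3

    penalty_j = alpha * (l1_ratio * abs(w_j) + 0.5 * (1. - l1_ratio) * w_j ** 2)
    assert penalty_j == alpha * abs(w_j)   # same penalty as before the commit

    dual_extra = 0.5 * alpha * (1 - l1_ratio) * norm_w2
    assert dual_extra == 0.0               # dual value unchanged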

celer/__init__.py

Lines changed: 4 additions & 2 deletions
@@ -1,8 +1,10 @@
 """Celer algorithm to solve L1-type regularized problems."""
 
 from .homotopy import celer_path
-from .dropin_sklearn import (Lasso, LassoCV, LogisticRegression, GroupLasso,
-                             GroupLassoCV, MultiTaskLasso, MultiTaskLassoCV)
+from .dropin_sklearn import (ElasticNet, ElasticNetCV,
+                             GroupLasso, GroupLassoCV,
+                             Lasso, LassoCV, LogisticRegression,
+                             MultiTaskLasso, MultiTaskLassoCV)
 
 
 __version__ = '0.7dev'
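
The new estimators are exposed at the package root. A minimal usage sketch, assuming ElasticNet and ElasticNetCV mirror the scikit-learn estimator API (alpha, l1_ratio, fit, coef_) like celer's other drop-in classes; the data below is synthetic and purely illustrative:

    import numpy as np
    from celer import ElasticNet, ElasticNetCV

    rng = np.random.default_rng(0)
    X = rng.standard_normal((50, 100))
    y = X @ rng.standard_normal(100) + 0.1 * rng.standard_normal(50)

    # assumed to follow sklearn.linear_model.ElasticNet's parametrization
    model = ElasticNet(alpha=0.1, l1_ratio=0.7).fit(X, y)
    print(np.flatnonzero(model.coef_))

    # cross-validated variant added in the same commit
    cv_model = ElasticNetCV(l1_ratio=0.7, cv=3).fit(X, y)
    print(cv_model.alpha_)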

celer/cython_utils.pxd

Lines changed: 9 additions & 7 deletions
@@ -8,8 +8,10 @@ cdef int LOGREG
 
 cdef floating ST(floating, floating) nogil
 
-cdef floating dual(int, int, floating, floating *, floating *) nogil
-cdef floating primal(int, floating, floating[:], floating [:],
+cdef floating fweighted_norm_w2(floating[:], floating[:]) nogil
+
+cdef floating dual(int, int, floating, floating, floating, floating, floating *, floating *) nogil
+cdef floating primal(int, floating, floating, floating[:], floating [:],
                      floating [:], floating[:]) nogil
 cdef void create_dual_pt(int, int, floating *, floating *, floating *) nogil
 
@@ -27,7 +29,7 @@ cdef void fposv(char *, int *, int *, floating *,
                 int *, floating *, int *, int *) nogil
 
 cdef int create_accel_pt(
-    int, int, int, int, floating, floating *, floating *,
+    int, int, int, int, floating *, floating *,
    floating *, floating[:, :], floating[:, :], floating[:], floating[:])
 
 
@@ -42,11 +44,11 @@ cpdef void compute_norms_X_col(
     floating[:], int[:], int[:], floating[:])
 
 
-cpdef floating dnorm_l1(
-    bint, floating[:], floating[::1, :], floating[:],
-    int[:], int[:], int[:], floating[:], floating[:], bint, bint) nogil
+cpdef floating dnorm_enet(
+    bint, floating[:], floating[:], floating[::1, :], floating[:],
+    int[:], int[:], int[:], floating[:], floating[:], bint, bint, floating, floating) nogil
 
 
 cdef void set_prios(
-    bint, floating[:], floating, floating[::1, :], floating[:], int[:],
+    bint, floating[:], floating[:], floating, floating, floating[::1, :], floating[:], int[:],
     int[:], floating[:], floating[:], floating[:], int[:], floating, int *, bint) nogil

celer/cython_utils.pyx

Lines changed: 45 additions & 17 deletions
@@ -109,6 +109,20 @@ cdef inline floating Nh(floating x) nogil:
         return INFINITY  # not - INFINITY
 
 
[email protected](False)
[email protected](False)
+cdef floating fweighted_norm_w2(floating[:] w, floating[:] weights) nogil:
+    cdef floating weighted_norm = 0.
+    cdef int n_features = w.shape[0]
+    cdef int j
+
+    for j in range(n_features):
+        if weights[j] == INFINITY:
+            continue
+        weighted_norm += weights[j] * w[j] ** 2
+    return weighted_norm
+
+
 @cython.boundscheck(False)
 @cython.wraparound(False)
 @cython.cdivision(True)
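
The helper added above is the weighted squared l2 norm of w used by the new dual, skipping features whose penalty weight is infinite (i.e. excluded features). A NumPy mirror for reference (not part of the commit):

    import numpy as np

    def fweighted_norm_w2_ref(w, weights):
        # sum_j weights[j] * w[j] ** 2, ignoring +inf weights
        mask = np.isfinite(weights)
        return np.sum(weights[mask] * w[mask] ** 2)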
@@ -141,7 +155,7 @@ cdef floating primal_logreg(
 @cython.wraparound(False)
 @cython.cdivision(True)
 cdef floating primal_lasso(
-        floating alpha, floating[:] R, floating[:] w,
+        floating alpha, floating l1_ratio, floating[:] R, floating[:] w,
         floating[:] weights) nogil:
     cdef int n_samples = R.shape[0]
     cdef int n_features = w.shape[0]
@@ -152,24 +166,27 @@ cdef floating primal_lasso(
     for j in range(n_features):
         # avoid nan when weights[j] is INFINITY
         if w[j]:
-            p_obj += alpha * weights[j] * fabs(w[j])
+            p_obj += alpha * weights[j] * (
+                l1_ratio * fabs(w[j]) +
+                0.5 * (1. - l1_ratio) * w[j] ** 2)
     return p_obj
 
 
 cdef floating primal(
-        int pb, floating alpha, floating[:] R, floating[:] y,
+        int pb, floating alpha, floating l1_ratio, floating[:] R, floating[:] y,
         floating[:] w, floating[:] weights) nogil:
     if pb == LASSO:
-        return primal_lasso(alpha, R, w, weights)
+        return primal_lasso(alpha, l1_ratio, R, w, weights)
     else:
         return primal_logreg(alpha, R, y, w, weights)
 
 
 @cython.boundscheck(False)
 @cython.wraparound(False)
 @cython.cdivision(True)
-cdef floating dual_lasso(int n_samples, floating norm_y2,
-                         floating * theta, floating * y) nogil:
+cdef floating dual_enet(int n_samples, floating alpha, floating l1_ratio,
+                        floating norm_y2, floating norm_w2, floating * theta,
+                        floating * y) nogil:
     """Theta must be feasible"""
     cdef int i
     cdef floating d_obj = 0.
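
For reference, the reworked primal now evaluates an ElasticNet penalty per feature. A NumPy sketch of what primal_lasso computes (not part of the commit; the ||R||^2 / (2 n_samples) data-fit term is assumed unchanged from the Lasso path):

    import numpy as np

    def enet_primal_ref(alpha, l1_ratio, R, w, weights):
        p_obj = R @ R / (2 * len(R))       # assumed unchanged data-fit term
        ok = np.isfinite(weights)          # skip excluded (inf-weight) features
        p_obj += alpha * np.sum(
            weights[ok] * (l1_ratio * np.abs(w[ok])
                           + 0.5 * (1. - l1_ratio) * w[ok] ** 2))
        return p_obj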
@@ -178,6 +195,8 @@ cdef floating dual_lasso(int n_samples, floating norm_y2,
         d_obj -= (y[i] - n_samples * theta[i]) ** 2
     d_obj *= 0.5 / n_samples
     d_obj += norm_y2 / (2. * n_samples)
+    if l1_ratio != 1.0:
+        d_obj -= 0.5 * alpha * (1 - l1_ratio) * norm_w2
     return d_obj
 
 
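The dual value above, written out with NumPy (a reference sketch, not part of the commit; norm_y2 is assumed to be ||y||^2 and norm_w2 the weighted squared norm returned by fweighted_norm_w2):

    import numpy as np

    def enet_dual_ref(alpha, l1_ratio, theta, y, norm_w2):
        n_samples = len(y)
        d_obj = (y @ y - np.sum((y - n_samples * theta) ** 2)) / (2 * n_samples)
        if l1_ratio != 1.0:
            # extra term contributed by the l2 part of the ElasticNet penalty
            d_obj -= 0.5 * alpha * (1 - l1_ratio) * norm_w2
        return d_obj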
@@ -195,11 +214,10 @@ cdef floating dual_logreg(int n_samples, floating * theta,
     return d_obj
 
 
-cdef floating dual(int pb, int n_samples, floating norm_y2,
-                   floating * theta, floating * y) nogil:
-
+cdef floating dual(int pb, int n_samples, floating alpha, floating l1_ratio,
+                   floating norm_y2, floating norm_w2, floating * theta, floating * y) nogil:
     if pb == LASSO:
-        return dual_lasso(n_samples, norm_y2, &theta[0], &y[0])
+        return dual_enet(n_samples, alpha, l1_ratio, norm_y2, norm_w2, &theta[0], &y[0])
     else:
         return dual_logreg(n_samples, &theta[0], &y[0])
 
@@ -226,7 +244,6 @@ cdef void create_dual_pt(
 @cython.cdivision(True)
 cdef int create_accel_pt(
     int pb, int n_samples, int epoch, int gap_freq,
-    floating alpha,
     floating * R, floating * out, floating * last_K_R, floating[:, :] U,
     floating[:, :] UtU, floating[:] onesK, floating[:] y):
 
@@ -365,11 +382,11 @@ cpdef void compute_Xw(
 @cython.boundscheck(False)
 @cython.wraparound(False)
 @cython.cdivision(True)
-cpdef floating dnorm_l1(
-        bint is_sparse, floating[:] theta, floating[::1, :] X,
+cpdef floating dnorm_enet(
+        bint is_sparse, floating[:] theta, floating[:] w, floating[::1, :] X,
         floating[:] X_data, int[:] X_indices, int[:] X_indptr, int[:] skip,
         floating[:] X_mean, floating[:] weights, bint center,
-        bint positive) nogil:
+        bint positive, floating alpha, floating l1_ratio) nogil:
     """compute norm(X[:, ~skip].T.dot(theta), ord=inf)"""
     cdef int n_samples = theta.shape[0]
     cdef int n_features = skip.shape[0]
@@ -399,6 +416,10 @@ cpdef floating dnorm_l1(
         else:
             Xj_theta = fdot(&n_samples, &theta[0], &inc, &X[0, j], &inc)
 
+        # minus sign to consider the choice theta = y - Xw and not theta = Xw -y
+        if l1_ratio != 1:
+            Xj_theta -= alpha * (1 - l1_ratio) * weights[j] * w[j]
+
         if not positive:
             Xj_theta = fabs(Xj_theta)
         scal = max(scal, Xj_theta / weights[j])
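
In the dense, uncentered case the function therefore computes the quantity below (NumPy sketch, not part of the commit; skip marks screened columns):

    import numpy as np

    def dnorm_enet_ref(theta, w, X, weights, skip, alpha, l1_ratio, positive=False):
        corr = X.T @ theta
        if l1_ratio != 1:
            # minus sign because theta is built from y - Xw, not Xw - y
            corr -= alpha * (1 - l1_ratio) * weights * w
        if not positive:
            corr = np.abs(corr)
        keep = ~skip.astype(bool)
        return np.max(corr[keep] / weights[keep])   # inf weights contribute 0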
@@ -409,14 +430,15 @@
 @cython.wraparound(False)
 @cython.cdivision(True)
 cdef void set_prios(
-        bint is_sparse, floating[:] theta, floating alpha,
+        bint is_sparse, floating[:] theta, floating[:] w, floating alpha, floating l1_ratio,
         floating[::1, :] X, floating[:] X_data, int[:] X_indices, int[:] X_indptr,
         floating[:] norms_X_col, floating[:] weights, floating[:] prios,
         int[:] screened, floating radius, int * n_screened, bint positive) nogil:
     cdef int i, j, startptr, endptr
     cdef floating Xj_theta
     cdef int n_samples = theta.shape[0]
     cdef int n_features = prios.shape[0]
+    cdef floating norms_X_col_j = 0.
 
     # TODO we do not substract theta_sum, which seems to indicate that theta
     # is always centered...
@@ -433,11 +455,17 @@
         else:
             Xj_theta = fdot(&n_samples, &theta[0], &inc, &X[0, j], &inc)
 
+        norms_X_col_j = norms_X_col[j]
+        if l1_ratio != 1:
+            Xj_theta -= alpha * (1 - l1_ratio) * weights[j] * w[j]
+
+            norms_X_col_j = norms_X_col_j ** 2
+            norms_X_col_j += sqrt(norms_X_col_j + alpha * (1 - l1_ratio) * weights[j])
 
         if positive:
-            prios[j] = fabs(Xj_theta - alpha * weights[j]) / norms_X_col[j]
+            prios[j] = fabs(Xj_theta - alpha * l1_ratio * weights[j]) / norms_X_col_j
         else:
-            prios[j] = (alpha * weights[j] - fabs(Xj_theta)) / norms_X_col[j]
+            prios[j] = (alpha * l1_ratio * weights[j] - fabs(Xj_theta)) / norms_X_col_j
 
         if prios[j] > radius:
             screened[j] = True
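
A scalar Python mirror of the per-feature priority computed in the loop above (non-positive case), for illustration only; it follows the commit's update verbatim, and with l1_ratio == 1 it falls back to the previous Lasso formula:

    def prio_j(Xj_theta, w_j, weight_j, norm_Xj, alpha, l1_ratio):
        norm_j = norm_Xj
        if l1_ratio != 1:
            Xj_theta -= alpha * (1 - l1_ratio) * weight_j * w_j
            norm_j = norm_j ** 2
            norm_j += (norm_j + alpha * (1 - l1_ratio) * weight_j) ** 0.5
        return (alpha * l1_ratio * weight_j - abs(Xj_theta)) / norm_j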
