Last active January 15, 2021 06:00
wyattowalsh/3bfb1a924007f19a7191a17b6c4e52a0
Implementation of the Elastic Net for Regression in Python Using NumPy
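For reference, the coordinate update in the function below matches the glmnet-style elastic net objective sketched here. This is a derivation inferred from the update rule, not a formula stated by the author. It assumes N observations (`m` in the code), columns that are centered and satisfy (1/N) x_j^T x_j = 1, and an unpenalized intercept beta_0. Lambda corresponds to `l` (and to each entry of `l_path`), and S is the soft-thresholding operator. Because the columns are centered, x_j^T y equals x_j^T (y - mean(y)) in the lambda_max formula.

```latex
\begin{aligned}
&\min_{\beta_0,\,\beta}\ \frac{1}{2N}\,\lVert y - \beta_0\mathbf{1} - X\beta\rVert_2^2
  + \lambda\Big(\alpha\,\lVert\beta\rVert_1 + \frac{1-\alpha}{2}\,\lVert\beta\rVert_2^2\Big) \\
&z_j = \frac{1}{N}\,x_j^\top\big(y - \beta_0\mathbf{1} - X\beta\big) + \beta_j,
  \qquad \beta_j \leftarrow \frac{S(z_j,\ \lambda\alpha)}{1 + \lambda(1-\alpha)} \\
&S(z,\gamma) = \operatorname{sign}(z)\,\max(\lvert z\rvert - \gamma,\ 0),
  \qquad \lambda_{\max} = \frac{\max_j \lvert x_j^\top y\rvert}{N\alpha}
\end{aligned}
```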
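import numpy as np


def elastic_net(X, y, l, alpha, tol=1e-4, path_length=100, return_path=False):
    """The Elastic Net Regression model with intercept term.

    The intercept term is included via design matrix augmentation and is not penalized.
    Pathwise coordinate descent with covariance updates is applied.
    The path runs from l_max (the smallest tuning parameter value at which every
    feature coefficient is zero) down to the input tuning parameter value.
    Features must be standardized (centered and scaled to unit variance with the
    1/N population variance, e.g. np.std with its default ddof=0).

    Params:
        X - NumPy 2-D array, size (N, p), of standardized numerical predictors
        y - NumPy 1-D array, length N, of numerical response
        l - overall penalty strength, lambda (positive scalar)
        alpha - L1/L2 mixing parameter (scalar in [0, 1]; 1 is the lasso, 0 is ridge)
        tol - coordinate descent convergence tolerance (a path step ends once every
              coefficient changes by less than tol over a full sweep)
        path_length - number of tuning parameter values in the path (integer >= 2)
        return_path - if True, return the coefficients at every path value visited
    Returns:
        NumPy array, length p + 1, of fitted model coefficients (intercept first);
        if return_path is True, a 2-D array with one such row per path value visited
    """
    # Augment the design matrix with a leading column of ones for the intercept
    X = np.hstack((np.ones((len(X), 1)), X))
    m, n = np.shape(X)
    B_star = np.zeros(n)
    # alpha == 0 (pure ridge) would divide by zero in l_max; use a tiny alpha instead
    if alpha == 0:
        alpha = 1e-15
    # Smallest penalty that zeroes every feature coefficient (intercept column excluded)
    l_max = np.max(np.abs(X[:, 1:].T @ y)) / m / alpha
    if l >= l_max:
        # Null model: intercept equals the mean response, all other coefficients zero
        B_null = np.append(np.mean(y), np.zeros(n - 1))
        return B_null[np.newaxis, :] if return_path else B_null
    # Geometrically spaced penalty values from l_max down to l
    l_path = np.geomspace(l_max, l, path_length)
    path = np.zeros((path_length, n))
    for i in range(path_length):
        while True:
            # Copy (do not alias) the previous sweep so the convergence check is meaningful
            B_s = B_star.copy()
            for j in range(n):
                # Active set: indices of the currently nonzero coefficients
                k = np.where(B_star != 0)[0]
                # (1/m) X_j^T (y - X B) + B_j, i.e. (1/m) X_j^T times the partial residual without j
                update = (X[:, j] @ y - (X[:, j] @ X[:, k]) @ B_star[k]) / m + B_star[j]
                if j == 0:
                    # The intercept is not penalized
                    B_star[j] = update
                else:
                    # Soft-threshold by the L1 penalty, then shrink by the L2 penalty
                    B_star[j] = (np.sign(update) * max(abs(update) - l_path[i] * alpha, 0)) \
                        / (1 + l_path[i] * (1 - alpha))
            if np.all(np.abs(B_s - B_star) < tol):
                break
        # Record this fit; B_star carries over as the warm start for the next path value
        path[i] = B_star
    return path if return_path else B_star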
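The docstring depends on three ideas: population-variance standardization, the soft-thresholded coordinate update, and l_max as the point where every feature coefficient is zero. The sketch below shows each one in isolation. The helper names `standardize`, `soft_threshold`, `coordinate_update`, and `lambda_max` are hypothetical and not part of the gist, and the data is synthetic. The final lines check that at l_max, every feature update from the null model (intercept mean(y), all other coefficients zero) is zero.

```python
import numpy as np


def standardize(X):
    """Center each column and scale it to unit population variance (ddof=0)."""
    return (X - X.mean(axis=0)) / X.std(axis=0)


def soft_threshold(z, gamma):
    """S(z, gamma) = sign(z) * max(|z| - gamma, 0)."""
    return np.sign(z) * max(abs(z) - gamma, 0.0)


def coordinate_update(X, r, b_j, j, lam, alpha):
    """One elastic net update for feature column j of standardized X.

    r is the full residual y - intercept - X @ b, and b_j is the current coefficient.
    """
    m = X.shape[0]
    z = X[:, j] @ r / m + b_j  # fit of column j to the partial residual excluding j
    return soft_threshold(z, lam * alpha) / (1 + lam * (1 - alpha))


def lambda_max(X, y, alpha):
    """Smallest lambda at which every feature coefficient is zero (centered X)."""
    return np.max(np.abs(X.T @ y)) / X.shape[0] / alpha


# Synthetic data: 50 rows, 4 standardized features, intercept 3.0
rng = np.random.default_rng(0)
X = standardize(rng.normal(size=(50, 4)))
y = X @ np.array([2.0, 0.0, -1.0, 0.5]) + 3.0 + rng.normal(scale=0.1, size=50)

alpha = 0.5
lam = lambda_max(X, y, alpha)
r = y - y.mean()  # residual of the null model (intercept = mean(y), b = 0)
updates = [coordinate_update(X, r, 0.0, j, lam, alpha) for j in range(X.shape[1])]
print(updates)  # all entries are zero, up to floating-point rounding
```

Because `standardize` uses ddof=0, `X[:, j] @ X[:, j] / N` equals 1. This is why the update can add `b_j` back directly without dividing by a column norm, matching the gist's requirement that features be scaled to unit variance.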