Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[OSCP]利用SPU实现分位数回归算法 #865

Merged
merged 5 commits into from
Oct 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions sml/linear_model/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -54,3 +54,11 @@ py_binary(
"//sml/linear_model/utils:solver",
],
)

py_library(
name = "quantile",
srcs = ["quantile.py"],
deps = [
"//sml/linear_model/utils:_linprog_simplex",
],
)
9 changes: 9 additions & 0 deletions sml/linear_model/emulations/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -62,3 +62,12 @@ py_binary(
"//sml/utils:emulation",
],
)

py_binary(
name = "quantile_emul",
srcs = ["quantile_emul.py"],
deps = [
"//sml/linear_model:quantile",
"//sml/utils:emulation",
],
)
108 changes: 108 additions & 0 deletions sml/linear_model/emulations/quantile_emul.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
# Copyright 2024 Ant Group Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import time

import jax.numpy as jnp
from sklearn.linear_model import QuantileRegressor as SklearnQuantileRegressor

import sml.utils.emulation as emulation
from sml.linear_model.quantile import QuantileRegressor as SmlQuantileRegressor

CONFIG_FILE = emulation.CLUSTER_ABY3_3PC


def emul_quantile(mode=emulation.Mode.MULTIPROCESS):
def proc_wrapper(
quantile,
alpha,
fit_intercept,
lr,
max_iter,
):
quantile_custom = SmlQuantileRegressor(
quantile=quantile,
alpha=alpha,
fit_intercept=fit_intercept,
lr=lr,
max_iter=max_iter,
)

def proc(X, y):
quantile_custom_fit = quantile_custom.fit(X, y)
result = quantile_custom_fit.predict(X)
return result, quantile_custom_fit.coef_, quantile_custom_fit.intercept_

return proc

def generate_data():
from jax import random

# 设置随机种子
key = random.PRNGKey(42)
# 生成 X 数据
key, subkey = random.split(key)
X = random.normal(subkey, (100, 2))
# 生成 y 数据
y = (
5 * X[:, 0] + 2 * X[:, 1] + random.normal(key, (100,)) * 0.1
) # 高相关性,带有小噪声
return X, y

try:
# bandwidth and latency only work for docker mode
emulator = emulation.Emulator(CONFIG_FILE, mode, bandwidth=300, latency=20)
emulator.up()

# load mock data
X, y = generate_data()

# compare with sklearn
quantile_sklearn = SklearnQuantileRegressor(
quantile=0.2, alpha=0.1, fit_intercept=True, solver='highs'
)
start = time.time()
quantile_sklearn_fit = quantile_sklearn.fit(X, y)
y_pred_plain = quantile_sklearn_fit.predict(X)
rmse_plain = jnp.sqrt(jnp.mean((y - y_pred_plain) ** 2))
end = time.time()
print(f"Running time in SKlearn: {end - start:.2f}s")
print(quantile_sklearn_fit.coef_)
print(quantile_sklearn_fit.intercept_)

# mark these data to be protected in SPU
X_spu, y_spu = emulator.seal(X, y)

# run
# Larger max_iter can give higher accuracy, but it will take more time to run
proc = proc_wrapper(
quantile=0.2, alpha=0.1, fit_intercept=True, lr=0.01, max_iter=200
)
start = time.time()
result, coef, intercept = emulator.run(proc)(X_spu, y_spu)
end = time.time()
rmse_encrpted = jnp.sqrt(jnp.mean((y - result) ** 2))
print(f"Running time in SPU: {end - start:.2f}s")
print(coef)
print(intercept)

# print RMSE
print(f"RMSE in SKlearn: {rmse_plain:.2f}")
print(f"RMSE in SPU: {rmse_encrpted:.2f}")

finally:
emulator.down()


if __name__ == "__main__":
emul_quantile(emulation.Mode.MULTIPROCESS)
196 changes: 196 additions & 0 deletions sml/linear_model/quantile.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,196 @@
# Copyright 2024 Ant Group Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import jax
import jax.numpy as jnp
import pandas as pd
from jax import grad

from sml.linear_model.utils._linprog_simplex import _linprog_simplex


class QuantileRegressor:
"""
Initialize the quantile regression model.
Parameters
----------
quantile : float, default=0.5
The quantile to be predicted. Must be between 0 and 1.
A quantile of 0.5 corresponds to the median (50th percentile).
alpha : float, default=1.0
Regularization strength; must be a positive float.
Larger values specify stronger regularization, reducing model complexity.
fit_intercept : bool, default=True
Whether to calculate the intercept for the model.
If False, no intercept will be used in calculations, meaning the model will
assume that the data is already centered.
lr : float, default=0.01
Learning rate for the optimization process. This controls the size of
the steps taken in each iteration towards minimizing the objective function.
max_iter : int, default=1000
The maximum number of iterations for the optimization algorithm.
This controls how long the model will continue to update the weights
before stopping.
max_val : float, default=1e10
The maximum value allowed for the model parameters.
Attributes
----------
coef_ : array-like of shape (n_features,)
The coefficients (weights) assigned to the input features. These will be
learned during model fitting.
intercept_ : float
The intercept (bias) term. If `fit_intercept=True`, this will be
learned during model fitting.
"""

def __init__(
self,
quantile=0.5,
alpha=1.0,
fit_intercept=True,
lr=0.01,
max_iter=1000,
max_val=1e10,
):
self.quantile = quantile
self.alpha = alpha
self.fit_intercept = fit_intercept
self.lr = lr
self.max_iter = max_iter
self.max_val = max_val

self.coef_ = None
self.intercept_ = None

def fit(self, X, y, sample_weight=None):
"""
Fit the quantile regression model using linear programming.
Parameters
----------
X : array-like of shape (n_samples, n_features)
Training data.
y : array-like of shape (n_samples,)
Target values.
sample_weight : array-like of shape (n_samples,), optional
Individual weights for each sample. If not provided, all samples
are assumed to have equal weight.
Returns
-------
self : object
Returns an instance of self.
Steps:
1. Determine the number of parameters (`n_params`), accounting for the intercept if needed.
2. Define the objective function `c`, incorporating both the L1 regularization and the pinball loss.
3. Set up the equality constraint matrix `A_eq` and vector `b_eq` based on the input data `X` and `y`.
4. Solve the linear programming problem using `_linprog_simplex`.
5. Extract the model parameters (intercept and coefficients) from the solution.
"""
n_samples, n_features = X.shape
n_params = n_features

if sample_weight is None:
sample_weight = jnp.ones((n_samples,))

if self.fit_intercept:
n_params += 1

alpha = jnp.sum(sample_weight) * self.alpha

# After rescaling alpha, the minimization problem is
# min sum(pinball loss) + alpha * L1
# Use linear programming formulation of quantile regression
# min_x c x
# A_eq x = b_eq
# 0 <= x
# x = (s0, s, t0, t, u, v) = slack variables >= 0
# intercept = s0 - t0
# coef = s - t
# c = (0, alpha * 1_p, 0, alpha * 1_p, quantile * 1_n, (1-quantile) * 1_n)
# residual = y - X@coef - intercept = u - v
# A_eq = (1_n, X, -1_n, -X, diag(1_n), -diag(1_n))
# b_eq = y
# p = n_features
# n = n_samples
# 1_n = vector of length n with entries equal one
# see https://stats.stackexchange.com/questions/384909/
c = jnp.concatenate(
[
jnp.full(2 * n_params, fill_value=alpha),
sample_weight * self.quantile,
sample_weight * (1 - self.quantile),
]
)

if self.fit_intercept:
c = c.at[0].set(0)
c = c.at[n_params].set(0)

eye = jnp.eye(n_samples)
if self.fit_intercept:
ones = jnp.ones((n_samples, 1))
A = jnp.concatenate([ones, X, -ones, -X, eye, -eye], axis=1)
else:
A = jnp.concatenate([X, -X, eye, -eye], axis=1)

b = y

result = _linprog_simplex(
c, A, b, maxiter=self.max_iter, tol=1e-3, max_val=self.max_val
)

solution = result

params = solution[:n_params] - solution[n_params : 2 * n_params]

if self.fit_intercept:
self.coef_ = params[1:]
self.intercept_ = params[0]
else:
self.coef_ = params
self.intercept_ = 0.0
return self

def predict(self, X):
"""
Predict target values using the fitted quantile regression model.
Parameters
----------
X : array-like of shape (n_samples, n_features)
Input data for which predictions are to be made.
Returns
-------
y_pred : array-like of shape (n_samples,)
Predicted target values.
Notes
-----
The predict method computes the predicted target values using the model's
learned coefficients and intercept (if fit_intercept=True).
- If the model includes an intercept, a column of ones is added to the input data `X` to account
for the intercept in the linear combination.
- The method then computes the dot product between the modified `X` and the stacked vector of
intercept and coefficients.
- If there is no intercept, the method simply computes the dot product between `X` and the coefficients.
"""

xbw886 marked this conversation as resolved.
Show resolved Hide resolved
assert (
self.coef_ is not None and self.intercept_ is not None
), "Model has not been fitted yet. Please fit the model before predicting."

n_features = len(self.coef_)
assert X.shape[1] == n_features, (
f"Input X must have {n_features} features, "
f"but got {X.shape[1]} features instead."
)

return jnp.dot(X, self.coef_) + self.intercept_
10 changes: 10 additions & 0 deletions sml/linear_model/tests/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -70,3 +70,13 @@ py_test(
"//spu/utils:simulation",
],
)

py_test(
name = "quantile_test",
srcs = ["quantile_test.py"],
deps = [
"//sml/linear_model:quantile",
"//spu:init",
"//spu/utils:simulation",
],
)
Loading
Loading