Kernel ridge regression #4492
Conversation
At 14800 rows, cuml takes 0.81s and sklearn takes 10.78s, a 13.4x speedup.

```python
import time
import numpy as np
import pandas as pd
from cuml import KernelRidge as cuKernelRidge
from sklearn.kernel_ridge import KernelRidge as sklKernelRidge
from sklearn.metrics import mean_squared_error as mse
import seaborn as sns
import matplotlib.pyplot as plt
from tqdm import tqdm

sns.set()
rows_all = np.arange(100, 15000, 300)
# rows_all = np.arange(100, 500, 300)
cols_all = [100]
iterations = 5
rs = np.random.RandomState(2)
estimators = {"sklearn": sklKernelRidge(), "cuml": cuKernelRidge()}
df = pd.DataFrame()
use_cache = False

if not use_cache:
    for n_rows in tqdm(rows_all):
        for n_cols in cols_all:
            X = rs.normal(size=(n_rows, n_cols))
            y = rs.normal(size=n_rows)
            for name, alg in estimators.items():
                # warmup
                alg.fit(X[0:10], y[0:10])
                for i in range(iterations):
                    start = time.perf_counter()
                    alg.fit(X, y)
                    pred = alg.predict(X)
                    time_taken = time.perf_counter() - start
                    if "cupy" in str(type(pred)):
                        pred = pred.get()
                    df = df.append(
                        {
                            "Algorithm": name,
                            "n_rows": n_rows,
                            "n_cols": n_cols,
                            "MSE": mse(y, pred),
                            "Time": time_taken,
                            "Iteration": i,
                        },
                        ignore_index=True,
                    )

if use_cache:
    df = pd.read_pickle("kernel_rr.pkl")
else:
    df.to_pickle("kernel_rr.pkl")

int_cols = ["n_rows", "n_cols", "Iteration"]
df[int_cols] = df[int_cols].astype(int)

sns.lineplot(x="n_rows", y="Time", hue="Algorithm", data=df)
plt.yscale("log")
plt.xticks(rotation=45)
plt.title(
    "Kernel ridge regression time (linear kernel, {} features, float64)".format(
        cols_all[-1]
    )
)
plt.savefig("kernel_ridge_time.png")
plt.clf()

sns.barplot(x="n_rows", y="MSE", hue="Algorithm", data=df)
plt.xticks(rotation=45)
plt.title(
    "Kernel ridge regression MSE (linear kernel, {} features, float64)".format(
        cols_all[-1]
    )
)
plt.savefig("kernel_ridge_mse.png")

sklearn_largest_time = df[
    (df["n_rows"] == df["n_rows"].max()) & (df["Algorithm"] == "sklearn")
]["Time"].mean()
cuml_largest_time = df[
    (df["n_rows"] == df["n_rows"].max()) & (df["Algorithm"] == "cuml")
]["Time"].mean()
print(
    "At {} rows, cuml takes {}s, sklearn takes {}s, speedup is {}.".format(
        df["n_rows"].max(),
        cuml_largest_time,
        sklearn_largest_time,
        sklearn_largest_time / cuml_largest_time,
    )
)
```
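One note if rerunning this script on current pandas: `DataFrame.append` was deprecated in pandas 1.4 and removed in 2.0. A minimal sketch of the equivalent pattern (the values shown are placeholders):

```python
import pandas as pd

# Collect result rows in a plain list during the benchmark loop...
records = []
records.append({"Algorithm": "cuml", "n_rows": 100, "n_cols": 100,
                "MSE": 0.0, "Time": 0.0, "Iteration": 0})
# ...then build the DataFrame once at the end.
df = pd.DataFrame.from_records(records)
```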
Force-pushed from baba26f to c3bd2c4.
This should be on the 22.04 board, not 22.02.
I'm happy to see more kernel-based methods in cuml. This is a really nice port from scikit-learn and I'm thinking the new API for building custom kernels might even be useful for pairwise distances in general (maybe w/ an option to turn symmetry on and off).
```python
pairwise_kernels(X, Y, metric='linear')
```

```python
@cuda.jit(device=True)
def custom_rbf_kernel(x, y, gamma=None):
```
I very much like the ability to quickly build custom kernels. Have you done any profiling / benchmarking of this against the `cuml.metrics.pairwise_distances` API? I'm mostly curious to know the gap between the two, and whether there's a perf hit for the different memory access patterns.
It would be interesting to see the difference, but for now it doesn't really matter, as the matrix inversion dominates computation time: the kernel could be 5 times slower than a native CUDA version and we wouldn't see any real difference in end-to-end time.
The bigger disadvantage of this approach for me has been JIT compile time. It's in the range of a few hundred ms, which I think is reasonable.
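For intuition on why the solve dominates, here is a rough timing sketch (not from the PR; the sizes and ridge term are arbitrary) comparing the O(n^2 d) kernel construction against the O(n^3) solve for the dual coefficients:

```python
import time
import cupy as cp

n, d = 8000, 100
rs = cp.random.RandomState(0)
X = rs.normal(size=(n, d))
y = rs.normal(size=n)

start = time.perf_counter()
K = X @ X.T  # linear kernel: O(n^2 * d)
cp.cuda.runtime.deviceSynchronize()
kernel_time = time.perf_counter() - start

start = time.perf_counter()
# Ridge solve (K + alpha * I) w = y: O(n^3), dominates end-to-end time
dual_coef = cp.linalg.solve(K + 1.0 * cp.eye(n), y)
cp.cuda.runtime.deviceSynchronize()
solve_time = time.perf_counter() - start

print(f"kernel: {kernel_time:.3f}s, solve: {solve_time:.3f}s")
```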
No worries. My question isn't about this algorithm in particular. It's been on our todo list for quite a while to see how performant it would be to allow users to implement custom pairwise distance measures in Numba.
```python
return (X, Y)
```

```python
@given(kernel_arg_strategy(), array_strategy())
```
Love the use of hypothesis here. I'm hoping we will start using it more in cuml.
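For readers unfamiliar with hypothesis: a strategy like the `array_strategy()` referenced in the quoted test might look roughly like this (a sketch only; the PR's actual definitions may differ):

```python
import numpy as np
from hypothesis import strategies as st
from hypothesis.extra.numpy import arrays

@st.composite
def array_strategy(draw):
    # Draw a shared feature dimension and independent row counts,
    # so X and Y are compatible inputs for a pairwise kernel.
    n_cols = draw(st.integers(min_value=1, max_value=5))
    n_rows_x = draw(st.integers(min_value=2, max_value=20))
    n_rows_y = draw(st.integers(min_value=2, max_value=20))
    elements = st.floats(min_value=-10.0, max_value=10.0)
    X = draw(arrays(np.float64, (n_rows_x, n_cols), elements=elements))
    Y = draw(arrays(np.float64, (n_rows_y, n_cols), elements=elements))
    return (X, Y)
```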
Have addressed review comments. I changed the kernel implementations to build off existing primitives more, which had a couple of side effects. The JIT compilation overhead went away for most of the kernels, taking the overall test time from 30s down to 10s. The estimator also became much less accurate for float32 inputs, because before I was able to force intermediate calculations to double precision. Accordingly, the tolerance has been significantly loosened for float32 tests. The cosine kernel still uses the custom kernel path, as implementing it the sklearn way is just very inaccurate and caused me to fail some tests. Chi^2 kernels also still use the custom kernel path, as I can't immediately see how to get this from existing primitives. I might benchmark the custom versions against the newer versions later if I get time, but this is more a matter of curiosity.
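To illustrate the float32 side effect: a kernel built from primitives typically uses the expansion ||x - y||^2 = ||x||^2 + ||y||^2 - 2 x.y, which suffers catastrophic cancellation in single precision. A minimal sketch of that idea (a hypothetical helper, not the PR's actual code):

```python
import cupy as cp

def rbf_from_primitives(X, Y, gamma):
    # Squared distances via the norm expansion; fast (one GEMM),
    # but prone to cancellation in float32.
    XX = (X * X).sum(axis=1)[:, None]   # ||x||^2, shape (n, 1)
    YY = (Y * Y).sum(axis=1)[None, :]   # ||y||^2, shape (1, m)
    sq_dists = XX + YY - 2.0 * (X @ Y.T)
    cp.maximum(sq_dists, 0.0, out=sq_dists)  # clip negatives from rounding
    return cp.exp(-gamma * sq_dists)
```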
Benchmarks comparing custom kernel performance against the implementation using primitives. The custom kernel implementation falls off considerably at higher dimensions due to poor memory access patterns. It is still faster than sklearn.

```python
import cupy as cp
import numpy as np
from numba import cuda
from cuml.metrics import pairwise_kernels
from sklearn.metrics.pairwise import pairwise_kernels as skl_pairwise_kernels
import math
import time
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm

sns.set()
df = pd.DataFrame()
for col in tqdm(range(10, 110, 10)):
    rs = np.random.RandomState(259)
    X = rs.normal(size=(20000, col))
    X_device = cp.array(X)

    # warmup
    K = pairwise_kernels(X_device[0:10], metric='rbf')
    start = time.perf_counter()
    K = pairwise_kernels(X_device, metric='rbf')
    cp.cuda.runtime.deviceSynchronize()
    standard_time = time.perf_counter() - start
    df = df.append(
        {"Algorithm": 'rbf', "n_rows": X.shape[0], "n_cols": X.shape[1],
         "Time": standard_time}, ignore_index=True)

    @cuda.jit(device=True)
    def custom_rbf_kernel(x, y, gamma=None):
        if gamma is None:
            gamma = 1.0 / len(x)
        sum = 0.0
        for i in range(len(x)):
            sum += (x[i] - y[i]) ** 2
        return math.exp(-gamma * sum)

    start = time.perf_counter()
    K = skl_pairwise_kernels(X, metric='rbf')
    skl_time = time.perf_counter() - start
    df = df.append(
        {"Algorithm": 'rbf_skl', "n_rows": X.shape[0], "n_cols": X.shape[1],
         "Time": skl_time}, ignore_index=True)

    # warmup
    K = pairwise_kernels(X_device[0:10], metric=custom_rbf_kernel)
    start = time.perf_counter()
    K = pairwise_kernels(X_device, metric=custom_rbf_kernel)
    cp.cuda.runtime.deviceSynchronize()
    custom_time = time.perf_counter() - start
    df = df.append(
        {"Algorithm": 'rbf_custom', "n_rows": X.shape[0], "n_cols": X.shape[1],
         "Time": custom_time}, ignore_index=True)

print(df)
sns.lineplot(x='n_cols', y='Time', hue='Algorithm', data=df)
plt.yscale('log')
plt.title('Pairwise kernel time 20,000 rows, varying cols')
plt.savefig("custom_kernels.png")
```
```python
# Copyright (c) 2019-2022, NVIDIA CORPORATION.
```
Just noticed this: we should remove 2019 since this is a new file.
```python
z += x[i] * y[i]
x_norm += x[i] * x[i]
y_norm += y[i] * y[i]
return z / math.sqrt(x_norm * y_norm)
```
I believe this is how the `pairwise_distances` are computing the cosine as well, with the exception that it's `2 - [a.dot(b) / (sqrt(x_l2_norm) * sqrt(y_l2_norm))]` (and sqrt(a) * sqrt(b) = sqrt(a * b)). It looks like you are doing this as well. Are you saying there's a numerical issue that might be causing incorrect values?
The sklearn version here (https://github.com/scikit-learn/scikit-learn/blob/9f85c9d44965b764f40169ef2917e5f7a798684f/sklearn/metrics/pairwise.py#L1265), when ported using cupy and using cuml's normalize function, seemed to be numerically unstable to me. This is why I kept the custom kernel version. I can look more into it if necessary.
I'm just wondering if correcting the cosine distance from `cuml.metrics.pairwise_distances` back to a similarity might help eliminate the JIT overhead from this one as well. If not, we can always look further into it in the future. Thanks for changing the other ones!
Ah, I see. I have changed that one to use cosine distance too.
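For context, the two cosine paths discussed in this thread look roughly like the following; this is a simplified sketch, not the PR's actual code, and the function names are hypothetical:

```python
import math
import cupy as cp
from numba import cuda

def cosine_normalize_then_gemm(X, Y):
    # sklearn-style path: normalize rows once, then a single matrix product.
    Xn = X / cp.linalg.norm(X, axis=1, keepdims=True)
    Yn = Y / cp.linalg.norm(Y, axis=1, keepdims=True)
    return Xn @ Yn.T

@cuda.jit(device=True)
def cosine_per_pair(x, y):
    # Custom-kernel path (as in the quoted diff above): accumulate the dot
    # product and both norms in one pass, dividing only once at the end.
    z = 0.0
    x_norm = 0.0
    y_norm = 0.0
    for i in range(len(x)):
        z += x[i] * y[i]
        x_norm += x[i] * x[i]
        y_norm += y[i] * y[i]
    return z / math.sqrt(x_norm * y_norm)
```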
Changes LGTM
rerun tests
Codecov Report

```
@@           Coverage Diff            @@
##           branch-22.04   #4492   +/- ##
===============================================
  Coverage              ?   85.74%
===============================================
  Files                 ?      239
  Lines                 ?    19588
  Branches              ?        0
===============================================
  Hits                  ?    16796
  Misses                ?     2792
  Partials              ?        0
```
@gpucibot merge
Sklearn reference implementation: https://github.com/scikit-learn/scikit-learn/blob/7e1e6d09b/sklearn/kernel_ridge.py#L16

I've tried to avoid touching the C++/CUDA layer so far. Pairwise kernels are implemented based on a numba kernel for now. I've also used cupy's lapack wrapper to access cuSolver. The implementation of `pairwise_kernels` here can be reused to very easily implement kernel PCA (a sketch follows below).

Todo:

- [x] Single target fit/predict
- [x] Standard kernels implemented
- [x] Support custom kernels
- [x] Support sample weights
- [ ] ~~Support CSR X matrix. Maybe too difficult for this PR.~~
- [x] Multi-target fit/predict
- [x] Change .py files to .pyx and move to correct places.
- [x] Benchmarking on reasonably large files
- [x] Tests take less than 20s
- [x] Ensure correct handling of input/output array types (I think I need to be using CumlArray and maybe some decorators)
- [x] Documentation

Authors:
- Rory Mitchell (https://github.com/RAMitchell)

Approvers:
- Corey J. Nolet (https://github.com/cjnolet)
- Micka (https://github.com/lowener)

URL: rapidsai#4492
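As a rough illustration of the kernel PCA claim, here is a minimal sketch built on `pairwise_kernels` using the standard centered-Gram eigendecomposition. The function is hypothetical (not part of the PR), and it assumes `pairwise_kernels` forwards keyword kernel parameters as sklearn's version does:

```python
import cupy as cp
from cuml.metrics import pairwise_kernels

def kernel_pca_fit_transform(X, n_components, metric="rbf", **kernel_params):
    # Gram matrix from the new pairwise_kernels API.
    K = pairwise_kernels(X, metric=metric, **kernel_params)
    n = K.shape[0]
    # Double-center the Gram matrix: K_c = (I - 1/n) K (I - 1/n).
    one_n = cp.full((n, n), 1.0 / n)
    K_c = K - one_n @ K - K @ one_n + one_n @ K @ one_n
    # eigh returns eigenvalues in ascending order; flip to descending.
    eigvals, eigvecs = cp.linalg.eigh(K_c)
    eigvals, eigvecs = eigvals[::-1], eigvecs[:, ::-1]
    # Project onto the top components, scaled by sqrt(eigenvalue).
    return eigvecs[:, :n_components] * cp.sqrt(
        cp.maximum(eigvals[:n_components], 0.0))
```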