
device=cuda_exp is slower than device=cuda on lightgbm.cv #5693

Closed
ninist opened this issue Jan 31, 2023 · 5 comments

ninist commented Jan 31, 2023

Edit 2023-02-09

I tried a simplified case without RFECV in a follow-up comment below; the issue is reproducible using just lightgbm.cv.

Description

I built two versions of lightgbm: one with cuda_exp and one with cuda.

I do feature selection with sklearn.RFECV.

I instantiate a lightgbm.LGBMRegressor, and N times I call
RFECV(model, n_jobs=None, ...), each time with a slightly different subset of the training data (simply selecting a random 95% subset of the rows on each call).

The idea behind performing several (e.g. 25) RFECV runs is to reduce variability in the returned selected features.

The result is a collections.Counter object that counts how many times each feature was selected out of the N runs.
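
As a minimal sketch of that aggregation step (the run contents below are illustrative stand-ins for the pandas Index objects produced by the full script further down):

from collections import Counter

# Hypothetical selections from three RFECV runs (stand-ins for the
# column Index objects returned by feature_select_step below).
selected_runs = [["f1", "f2", "f7"], ["f1", "f7"], ["f1", "f2"]]

feature_counts = Counter()
for cols in selected_runs:
    feature_counts.update(cols)

print(feature_counts.most_common())  # [('f1', 3), ('f2', 2), ('f7', 2)]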

The issue is that device="cuda_exp" is much slower than device="cuda".

Specifically, if I import the module compiled with cuda_exp, both device="cuda_exp" and device="gpu" are much slower than device="cuda" and device="gpu" are when I import the module compiled with the older cuda.

Reproducible example

See code at the end of the post

Environment info

LightGBM version or commit hash: 9954bc4

Command(s) you used to install LightGBM

# Clone / copy two instances of the repository, one for cuda and one for cuda_exp.
# cuda
cd build
cmake -DUSE_GPU=1 -DUSE_CUDA=1 -DUSE_CUDA_EXP=0 -DOpenCL_LIBRARY=/usr/local/cuda-12.0/lib64/libOpenCL.so -DOpenCL_INCLUDE_DIR=/usr/local/cuda/include/ ..
make -j8
cd ../python-package/
python3 setup.py install --precompile
# move the base directory of the installed package to the new name `lightgbm_cuda`
# (This is highly nonstandard, but it makes it easy to import both builds side by side for comparison.)
# cuda_exp
cd build
cmake -DUSE_GPU=1 -DUSE_CUDA=0 -DUSE_CUDA_EXP=1 -DOpenCL_LIBRARY=/usr/local/cuda-12.0/lib64/libOpenCL.so -DOpenCL_INCLUDE_DIR=/usr/local/cuda/include/ ..
make -j8
cd ../python-package/
python3 setup.py install --precompile
# move the base directory of the installed package to the new name `lightgbm_cuda_exp`
# (This is highly nonstandard, but it makes it easy to import both builds side by side for comparison.)


  • sklearn-1.2.0
  • pandas-1.5.2
  • numpy-1.21.5
  • python-3.9.7

Hardware:

  • AMD Ryzen 7 5800H
  • NVIDIA GeForce RTX 3060 Laptop GPU [6 GB vram]
  • 16 GB of memory

Code to run a particular model

Please note that the dataset below is 100% bogus. I unfortunately cannot share the real dataset. I made a lazy attempt to give the example dataset below the same number of features and rows and an approximate range/collection of values as the real dataset.

To run a model, pass in one of the options: cuda, cuda_gpu, cuda_exp, cuda_exp_gpu, cpu.

The invocation will produce a logfile with the time it took to run it. I have included a log file from my own invocations below.
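
For example, if the script is saved as lightgbm_perf_test.py (filename hypothetical):

python3 lightgbm_perf_test.py cuda_exp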

from sklearn.feature_selection import RFECV
from sklearn.model_selection import train_test_split

from datetime import datetime, timedelta
import sys
import time
import numpy as np
import pandas as pd

from collections import Counter

if sys.argv[1] == "cuda":
    import lightgbm_cuda as lightgbm
    DEVICE="cuda"
    print(f"Using lightgbm_cuda with DEVICE={DEVICE}")

elif sys.argv[1] == "cuda_gpu":
    import lightgbm_cuda as lightgbm
    DEVICE="gpu"
    print(f"Using lightgbm_cuda with DEVICE={DEVICE}")

elif sys.argv[1] == "cpu":
    import lightgbm_cuda as lightgbm
    DEVICE="cpu"
    print(f"Using lightgbm_cuda with DEVICE={DEVICE}")

elif sys.argv[1] == "cuda_exp":
    import lightgbm_cuda_exp as lightgbm
    DEVICE="cuda_exp"
    print(f"Using lightgbm_cuda_exp with DEVICE={DEVICE}")

elif sys.argv[1] == "cuda_exp_gpu":
    import lightgbm_cuda_exp as lightgbm
    DEVICE="gpu"
    print(f"Using lightgbm_cuda_exp with DEVICE={DEVICE}")

#elif sys.argv[1] == "cuda_exp_cpu":
#    import lightgbm_cuda_exp as lightgbm
#    DEVICE="cpu"
#    print(f"Using lightgbm_cuda_exp with DEVICE={DEVICE}")

else:
    raise Exception("arg should be one of: cuda, cuda_gpu, cuda_exp, cuda_exp_gpu, cpu")

ARG = sys.argv[1]

np.random.seed(1)
N_continuous = 110
N_indicators = 65

ansi_underline = "\033[96m"
ansi_green = "\033[92m"
ansi_end = "\033[0m"

def get_data():
    """construct a demo dataset with approximately the same shape and format"""
    index = pd.date_range(
        datetime(2019, 6, 1),
        datetime(2022, 6, 1),
        freq="H",
    )

    columns = []
    for i in range(N_continuous):
        if i%2 == 0:
            col = np.random.uniform(low=0, high=1000, size=len(index))
        else:
            col = np.random.normal(0, 1, size=len(index))
        columns.append(col.ravel())

    for i in range(N_indicators):
        col = np.random.choice([0, 1], size=len(index))
        columns.append(col.ravel())

    df = pd.concat([pd.Series(x) for x in columns], axis=1)
    target = pd.Series(np.random.uniform(low=0, high=2000, size=len(index)))

    return df.astype(np.float32), target.astype(np.float32)


def feature_select_step(model, train_X, train_y):
    """run one instance of feature selection on one training set"""
    selector = RFECV(model, n_jobs=None, step=1, verbose=False, cv=2)
    selector.fit(
        train_X,
        train_y.ravel(),
    )
    cols = train_X.loc[:, selector.support_].columns
    return cols

def run_aggregated_feature_select(train_X, train_y):
    """run feature selection successively and aggregate the resulting columns in a Counter"""
    model = lightgbm.LGBMRegressor(
        random_state=1,
        #n_estimators=24,
        #num_leaves=16,
        objective="regression_l1",
        metrics="l2", # cuda_exp raises warnings with l1
        n_estimators=50,
        num_leaves=64,
        max_bin=63,
        device=DEVICE,
        #gpu_use_dp=True,
        gpu_platform_id=0,
        gpu_device_id=0,
        num_thread=28,
        verbose=0,
    )
    agg = []
    times = []

    print(ansi_underline + "\nRunning feature select for model: " + repr(model) + ansi_end)

    # how many feature selections to run (only do one in the minimal example/demo)
    N = 1
    t0 = time.time()
    for random_state in range(10, 10+N):
        print(ansi_green + f"train_test_split(random state={random_state})" + ansi_end)
        # arbitrarily exclude 5% of the rows to introduce some variability
        # these excluded rows are not used at all for the current iteration
        tts_train_X, _____, tts_train_y, _____ = train_test_split(train_X, train_y, test_size=0.05, random_state=random_state)
        cols = feature_select_step(model, tts_train_X, tts_train_y)
        agg.append(cols)
        tn = time.time()
        print("TIME:", tn-t0)
        times.append(tn-t0)
        t0 = tn
    C = Counter()
    for index in agg:
        C.update(index)

    with open("lightgbm_perf_test.log", "a") as f:
        f.write(f"{ARG}[{DEVICE}]: {times}, {model}\n\n")

    return C

if __name__ == '__main__':
    train_X, train_y = get_data()
    run_aggregated_feature_select(train_X, train_y)

Timing results

Below, each record is one execution of the above program. The identifier ARG[DEVICE] shows which library was imported and which device was passed. The number after it is the time in seconds, followed by a textual representation of the fitted model.

For cpu, the library is the older cuda build (no runs were performed with device="cpu" on the cuda_exp library).

cuda[cuda]: [88.90170121192932], LGBMRegressor(device='cuda', gpu_device_id=0, gpu_platform_id=0, max_bin=63,
              metrics='l2', n_estimators=50, num_leaves=64, num_thread=28,
              objective='regression_l1', random_state=1, verbose=0)

cuda_gpu[gpu]: [124.41583967208862], LGBMRegressor(device='gpu', gpu_device_id=0, gpu_platform_id=0, max_bin=63,
              metrics='l2', n_estimators=50, num_leaves=64, num_thread=28,
              objective='regression_l1', random_state=1, verbose=0)

cuda_exp[cuda_exp]: [331.1839654445648], LGBMRegressor(device='cuda_exp', gpu_device_id=0, gpu_platform_id=0, max_bin=63,
              metrics='l2', n_estimators=50, num_leaves=64, num_thread=28,
              objective='regression_l1', random_state=1, verbose=0)

cuda_exp_gpu[gpu]: [172.52361226081848], LGBMRegressor(device='gpu', gpu_device_id=0, gpu_platform_id=0, max_bin=63,
              metrics='l2', n_estimators=50, num_leaves=64, num_thread=28,
              objective='regression_l1', random_state=1, verbose=0)

cpu[cpu]: [77.27314162254333], LGBMRegressor(device='cpu', gpu_device_id=0, gpu_platform_id=0, max_bin=63,
              metrics='l2', n_estimators=50, num_leaves=64, num_thread=28,
              objective='regression_l1', random_state=1, verbose=0)

# Perform three more runs to replicate the results above
cuda_exp_gpu[gpu]: [169.93835926055908], LGBMRegressor(device='gpu', gpu_device_id=0, gpu_platform_id=0, max_bin=63,
              metrics='l2', n_estimators=50, num_leaves=64, num_thread=28,
              objective='regression_l1', random_state=1, verbose=0)

cuda_gpu[gpu]: [122.16611576080322], LGBMRegressor(device='gpu', gpu_device_id=0, gpu_platform_id=0, max_bin=63,
              metrics='l2', n_estimators=50, num_leaves=64, num_thread=28,
              objective='regression_l1', random_state=1, verbose=0)

cuda[cuda]: [89.40578246116638], LGBMRegressor(device='cuda', gpu_device_id=0, gpu_platform_id=0, max_bin=63,
              metrics='l2', n_estimators=50, num_leaves=64, num_thread=28,
              objective='regression_l1', random_state=1, verbose=0)

I note that CPU is faster in this case, and that cuda_exp is almost 4x slower than cuda.

Is this caused by cuda_exp having higher overhead than cuda?

I tried tweaking the suggested options: using double precision, changing the metric from l1 to l2 so the metric is computed on the GPU, adjusting max_bin, and trying to get rid of the sparseness warnings thrown by some models:

[LightGBM] [Warning] CUDA currently requires double precision calculations.
[LightGBM] [Warning] Using sparse features with CUDA is currently not supported.
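
For reference, the tweaks amounted to variations along these lines (a sketch; the exact combinations varied between runs):

# variations layered on top of the base parameters shown above
params_tweaks = {
    'gpu_use_dp': True,         # address the double-precision warning
    'metric': 'l2',             # l1 triggered warnings on cuda_exp
    'max_bin': 63,              # also tried other bin counts
    'is_enable_sparse': False,  # address the sparse-features warning
}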

None of this improved the performance of cuda_exp.

Lastly, I would be inclined to agree that RFECV may interact poorly with CUDA/GPU computing.

ninist commented Feb 9, 2023

I was able to reproduce this for just plain lightgbm.cv without sklearn.

With device="cuda_exp", the lightgbm.cv call takes 1.9s, while with device="cuda" it takes 0.9s.

This is on commit 9954bc4, just before cuda_exp was made the standard CUDA implementation and the old cuda was removed.

from datetime import datetime
import sys
import time
import numpy as np
import pandas as pd

if sys.argv[1] == "cuda":
    import lightgbm_cuda as lightgbm
    DEVICE="cuda"
    print(f"Using lightgbm_cuda with DEVICE={DEVICE}")

elif sys.argv[1] == "cuda_gpu":
    import lightgbm_cuda as lightgbm
    DEVICE="gpu"
    print(f"Using lightgbm_cuda with DEVICE={DEVICE}")

elif sys.argv[1] == "cpu":
    import lightgbm_cuda as lightgbm
    DEVICE="cpu"
    print(f"Using lightgbm_cuda with DEVICE={DEVICE}")

elif sys.argv[1] == "cpu_exp":
    import lightgbm_cuda_exp as lightgbm
    DEVICE="cpu"
    print(f"Using lightgbm_cuda_Exp with DEVICE={DEVICE}")

elif sys.argv[1] == "cuda_exp":
    import lightgbm_cuda_exp as lightgbm
    DEVICE="cuda_exp"
    print(f"Using lightgbm_cuda_exp with DEVICE={DEVICE}")

elif sys.argv[1] == "cuda_exp_gpu":
    import lightgbm_cuda_exp as lightgbm
    DEVICE="gpu"
    print(f"Using lightgbm_cuda_exp with DEVICE={DEVICE}")

else:
    raise Exception("arg should be one of: cuda, cuda_gpu, cpu, cpu_exp, cuda_exp, cuda_exp_gpu")

ARG = sys.argv[1]

np.random.seed(1)
N_continuous = 110
N_indicators = 65

ansi_underline = "\033[96m"
ansi_green = "\033[92m"
ansi_end = "\033[0m"

def get_data():
    """construct a demo dataset with approximately the same shape and format"""
    index = pd.date_range(
        datetime(2019, 6, 1),
        datetime(2022, 10, 20),
        freq="H",
    )

    columns = []
    for i in range(N_continuous):
        if i%2 == 0:
            col = np.random.uniform(low=0, high=1000, size=len(index))
        else:
            col = np.random.normal(0, 1, size=len(index))
        columns.append(col.ravel())

    for i in range(N_indicators):
        col = np.random.choice([0, 1], size=len(index))
        columns.append(col.ravel())

    df = pd.concat([pd.Series(x) for x in columns], axis=1)
    target = pd.Series(np.random.uniform(low=0, high=2000, size=len(index)))

    df = df.astype(np.float32)
    df.index = index
    target = target.astype(np.float32)
    target.index = df.index

    dataset = lightgbm.Dataset(df, target)
    return dataset

def run_lightgbm(dataset):
    """run lightgbm.cv once"""

    print(ansi_underline + "\nRunning lightgbm.cv:" + ansi_end)

    params = {
        'random_state': 1,
        #'n_estimators': 24,
        #'num_leaves': 16,
        'objective': 'regression_l1',
        'metrics': 'regression_l2', # cuda_exp raises warnings with l1 as opposed to l2
        #'n_estimators': 96, # num boosting rounds
        'num_leaves': 64,
        'max_depth': 15, # XXX
        'max_bin': 63, # XXX
        'device': DEVICE, # gpu cpu cuda cuda_exp
        #'gpu_use_dp': True,
        #'gpu_platform_id': 0,
        #'gpu_device_id': 0,
        'verbose': 0,
    }
    if DEVICE in ['cuda', 'cuda_exp']:
        params['is_enable_sparse'] = False


    t0 = time.time()
    model = lightgbm.cv(
        params=params,
        train_set=dataset,
        num_boost_round=99999999,
        shuffle=False,
        stratified=False, # regression does not permit stratified
        callbacks=[lightgbm.early_stopping(20, verbose=False), lightgbm.log_evaluation(1)],
        eval_train_metric=False,
        seed=2,
    )
    tn = time.time()
    print("TIME:", tn-t0)
    time_taken = tn-t0

    with open("lightgbm_cv_perf_test.log", "a") as f:
        f.write(f"{ARG}[{DEVICE}]: {time_taken}, {model}\n\n")

if __name__ == '__main__':
    dataset = get_data()
    run_lightgbm(dataset)

ninist changed the title from "sklearn.RFECV: cuda_exp is slower than cuda" to "cuda_exp is slower than cuda on lightgbm.cv (with or without sklearn.RFECV)" on Feb 9, 2023
ninist changed the title from "cuda_exp is slower than cuda on lightgbm.cv (with or without sklearn.RFECV)" to "device=cuda_exp is slower than device=cuda on lightgbm.cv (with or without sklearn.RFECV)" on Feb 9, 2023
ninist changed the title from "device=cuda_exp is slower than device=cuda on lightgbm.cv (with or without sklearn.RFECV)" to "device=cuda_exp is slower than device=cuda on lightgbm.cv" on Feb 9, 2023
shiyu1994 (Collaborator) commented

@ninist Thanks for the detailed benchmarking. I think cross-validation requires more rounds of data loading, and the data loading part of cuda_exp is not yet carefully optimized. That may be why the overall time is much slower.
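
One rough way to separate the data-loading cost from the training cost is to time Dataset construction explicitly (a sketch, with synthetic data standing in for the real dataset):

import time
import numpy as np
import lightgbm

X = np.random.rand(30000, 175).astype(np.float32)
y = np.random.rand(30000).astype(np.float32)

t0 = time.time()
dataset = lightgbm.Dataset(X, y)
dataset.construct()  # force binning/loading now rather than lazily inside train()
t1 = time.time()

booster = lightgbm.train(
    {'objective': 'regression_l1', 'device': 'cuda_exp', 'verbose': -1},
    dataset,
    num_boost_round=50,
)
t2 = time.time()

print(f"data loading: {t1 - t0:.3f}s, training: {t2 - t1:.3f}s")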

Is the training time of cuda_exp still slower than cuda without CV?

ninist commented Feb 10, 2023

Is the training time of cuda_exp still slower than cuda without CV?

Yes, the outcome is the same with just lightgbm.train: cuda_exp is still slower.

    # df and target are constructed as in get_data() above;
    # val_date is the chosen train/validation split date (not shown here)
    train_X = df[df.index <= val_date]
    train_y = target[target.index <= val_date]
    val_X = df[df.index >= val_date]
    val_y = target[target.index >= val_date]

    dataset_params = {"feature_pre_filter": False}
    dataset_train = lightgbm.Dataset(train_X, train_y, params=dataset_params)
    dataset_val = lightgbm.Dataset(val_X, val_y, params=dataset_params)

    model = lightgbm.train(
        params=params,
        train_set=dataset_train,
        num_boost_round=99999999,
        callbacks=[lightgbm.early_stopping(20, verbose=False), lightgbm.log_evaluation(1)],
        valid_sets=[dataset_val],
    )

The datasets have 175 features; the row counts are:

>>> dataset_train.num_data()
20929
>>> dataset_val.num_data()
8761

cuda 0.275s vs cuda_exp 0.615s
(cpu with 8 threads is slightly faster than cuda)

Larger training data

I tried changing the start year of the date range from 2019 to 1800 to see whether more training data changes the outcome (1940641 rows now) -- the result is the same.

cuda 5.356s vs cuda_exp 20.255s
(cpu with 8 threads is now 1-2s slower than cuda but much faster than cuda_exp)

jameslamb (Collaborator) commented

Sorry for the long delay in response.

The implementation that used to be called cuda has been removed, and the one that used to be called cuda_exp is now called cuda. LightGBM now only has a single CUDA implementation: #5677.
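
In current versions, selecting the CUDA implementation is therefore just (a sketch, assuming a build compiled with -DUSE_CUDA=1):

params = {
    'objective': 'regression_l1',
    'device': 'cuda',  # formerly 'cuda_exp'; the 'cuda' device name now refers to this implementation
}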

That implementation has also received significant improvements in the 14 months since there was last activity on this discussion.

Could you please try again, and hopefully with a smaller reproducible example?

I'm marking this awaiting response, which means it'll automatically be closed in 30 days. If you notice other performance issues with the CUDA implementation here in the future, please open new issues with minimal, reproducible examples, and we'll try to help.


This issue has been automatically closed because it has been awaiting a response for too long. When you have time to work with the maintainers to resolve this issue, please post a new comment and it will be re-opened. If the issue has been locked for editing by the time you return to it, please open a new issue and reference this one. Thank you for taking the time to improve LightGBM!
