
[Benchmark] Reuse optimum-benchmark #30615

Merged: 2 commits merged into main from benchmark_reuse_otpimum on May 21, 2024
Conversation

ydshieh (Collaborator) commented May 2, 2024

What does this PR do?

[Benchmark] Reuse optimum-benchmark

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@ydshieh ydshieh force-pushed the benchmark_reuse_otpimum branch 3 times, most recently from e3be7ec to 151f0db on May 6, 2024 14:10
ydshieh (Collaborator Author):

This is the default config, but we can override the values from the CLI.

ydshieh (Collaborator Author):

We can discuss some values defined here.

Comment on lines +232 to +230
metrics = [
    "prefill.latency.mean",
    "prefill.throughput.value",
    "decode.latency.mean",
    "decode.throughput.value",
    "per_token.latency.mean",
    "per_token.throughput.value",
]
ydshieh (Collaborator Author):

This should not be defined directly in this script, as this is task (experiment) dependent.
(So far we are using config/generation.yaml)

Comment on lines +257 to +250
elif len(commits) == 1 and commits[0] == "diff":
    # compare to `main`
    commits = ["main", current_head]
ydshieh (Collaborator Author):

Here we are!
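
As context for this thread, a minimal sketch of how the `diff` shortcut can expand into a `["main", current_head]` pair, with `current_head` taken from git. The `resolve_commits` helper name is illustrative, not the exact code in benchmark/benchmark.py.

```python
import subprocess

def resolve_commits(commits):
    # Hypothetical helper: resolve the list of commits to benchmark.
    current_head = subprocess.check_output(["git", "rev-parse", "HEAD"], text=True).strip()
    if len(commits) == 1 and commits[0] == "diff":
        # `diff` means: benchmark `main` and the current head, so the two can be compared
        commits = ["main", current_head]
    return commits
```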

Comment on lines 108 to 116
        self.inv_freq = 1.0 / (
            self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float() / self.dim)
        )

    @torch.no_grad()
    def forward(self, x, position_ids, seq_len=None):
        # x: [bs, num_attention_heads, seq_len, head_size]
        if self.inv_freq is None:
            self.inv_freq = 1.0 / (
                self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64, device=x.device).float() / self.dim)
            )
        self.inv_freq.to(device=x.device)

ydshieh (Collaborator Author):

This fix is necessary to make torch.compile work with this model. I will revert the change in this PR, but we should figure out the best fix in another PR.

ydshieh (Collaborator Author):

This is another wrapper that gives us control over the transformers commit being benchmarked, plus some manual report post-processing and display.
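
To make the wrapper idea concrete, here is a rough sketch under stated assumptions: the helper name, the flags forwarded, and the report handling are illustrative, not the actual benchmark.py. The idea is to check out a transformers commit, launch the optimum-benchmark CLI in a subprocess, and post-process its report afterwards.

```python
import subprocess

def run_benchmark_on_commit(commit, config_dir, config_name, overrides=()):
    # Hypothetical helper: put the working tree at the commit we want to measure...
    subprocess.run(["git", "checkout", commit], check=True)
    # ...then delegate the measurement to the optimum-benchmark CLI (hydra-based),
    # forwarding config overrides such as backend.model=google/gemma-2b.
    cmd = [
        "optimum-benchmark",
        "--config-dir", config_dir,
        "--config-name", config_name,
        *overrides,
        "--multirun",
    ]
    subprocess.run(cmd, check=True)
    # Report post-processing / display would happen here, after the run finishes.
```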

@ydshieh ydshieh marked this pull request as ready for review May 6, 2024 14:56
@ydshieh ydshieh requested a review from ArthurZucker May 6, 2024 14:56
Makefile Outdated
Comment on lines 99 to 103
# Run benchmark

benchmark:
	python3 benchmark/benchmark.py --config-dir benchmark/config --config-name generation --commit=diff backend.model=google/gemma-2b --multirun

ydshieh (Collaborator Author):

--commit=diff: this will run the benchmark on main and on the current branch/head.
(It won't work at this moment; it will only work after this PR is merged to main.)

backend.model=google/gemma-2b: this is temporary (for demonstration purposes).

The long-term plan is to have something like backend.model=$(cat models_to_run.txt), where models_to_run.txt would be prepared by a CI job step.

If no such file exists, it will run the default models (defined in benchmark/benchmark.py or ...).
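
A small sketch of that fallback logic; the file name and the default list are assumptions for illustration.

```python
import os

DEFAULT_MODELS = ["google/gemma-2b"]  # placeholder default; the real list would live in benchmark/benchmark.py

def models_to_benchmark(path="models_to_run.txt"):
    # If a CI job prepared a model list, use it; otherwise fall back to the defaults.
    if os.path.isfile(path):
        with open(path) as fp:
            return [line.strip() for line in fp if line.strip()]
    return DEFAULT_MODELS
```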

Collaborator:

We should include a workflow that pushes the results of running this on main to a dataset, and compare against them next.
But this is good for now!

ydshieh (Collaborator Author):

Thank you! Yes, step by step 🚀

ydshieh (Collaborator Author):

This is just a Python wrapper around the optimum-benchmark CLI. We don't really need this file; a simple change in benchmark/benchmark.py could do the job.

ArthurZucker (Collaborator) left a comment:

A naive run gave this:

arthur@brahms ~/transformers (benchmark_reuse_otpimum)> make benchmark                                                                                                                                                                               (py39) 
python3 benchmark/benchmark.py --config-dir benchmark/config --config-name generation --commit=diff backend.model=google/gemma-2b --multirun
Run benchmark on commit: df53c6e5d9245315c741ba6cce1e026d4ca104c5
2024-05-09 10:44:30.258357: I tensorflow/tsl/cuda/cudart_stub.cc:28] Could not find cuda drivers on your machine, GPU will not be used.
2024-05-09 10:44:30.297135: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2024-05-09 10:44:30.874467: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
/home/arthur/miniconda3/envs/py39/lib/python3.9/site-packages/torchvision/io/image.py:13: UserWarning: Failed to load image Python extension: '/home/arthur/miniconda3/envs/py39/lib/python3.9/site-packages/torchvision/image.so: undefined symbol: _ZN3c1017RegisterOperatorsD1Ev'If you don't plan on using image functionality from `torchvision.io`, you can ignore this warning. Otherwise, there might be something wrong with your environment. Did you have `libjpeg` or `libpng` installed before building `torchvision` from source?
  warn(
Additional config directory '/home/arthur/transformers/benchmark/config' not found

Set the environment variable HYDRA_FULL_ERROR=1 for a complete stack trace.
Run benchmark on commit: 195e1adf82eccc994e54412de4638dc3b5f7ee6a
2024-05-09 10:44:36.455346: I tensorflow/tsl/cuda/cudart_stub.cc:28] Could not find cuda drivers on your machine, GPU will not be used.
2024-05-09 10:44:36.495380: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2024-05-09 10:44:37.091402: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
/home/arthur/miniconda3/envs/py39/lib/python3.9/site-packages/torchvision/io/image.py:13: UserWarning: Failed to load image Python extension: '/home/arthur/miniconda3/envs/py39/lib/python3.9/site-packages/torchvision/image.so: undefined symbol: _ZN3c1017RegisterOperatorsD1Ev'If you don't plan on using image functionality from `torchvision.io`, you can ignore this warning. Otherwise, there might be something wrong with your environment. Did you have `libjpeg` or `libpng` installed before building `torchvision` from source?
  warn(
In 'hydra/config': Could not find 'hydra/job_logging/colorlog'

Available options in 'hydra/job_logging':
        default
        disabled
        none
        stdout
Config search path:
        provider=hydra, path=pkg://hydra.conf
        provider=main, path=pkg://optimum_benchmark
        provider=command-line, path=file:///home/arthur/transformers/benchmark/config
        provider=schema, path=structured://

Set the environment variable HYDRA_FULL_ERROR=1 for a complete stack trace.
{}

pip install hydra_colorlog is needed.
Could make benchmark install all the dependencies?

One thing that is needed now is just a small post-processing step that tells you whether you are slower or not!

Also, there are too many logs from optimum-benchmark; we should deactivate all of them by default!

python3 benchmark/benchmark.py --config-dir benchmark/config --config-name generation --commit=diff backend.model=google/gemma-2b --multirun
Run benchmark on commit: df53c6e5d9245315c741ba6cce1e026d4ca104c5
2024-05-09 11:04:39.749127: I tensorflow/tsl/cuda/cudart_stub.cc:28] Could not find cuda drivers on your machine, GPU will not be used.
2024-05-09 11:04:39.789786: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2024-05-09 11:04:40.398935: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
/home/arthur/miniconda3/envs/py39/lib/python3.9/site-packages/torchvision/io/image.py:13: UserWarning: Failed to load image Python extension: '/home/arthur/miniconda3/envs/py39/lib/python3.9/site-packages/torchvision/image.so: undefined symbol: _ZN3c1017RegisterOperatorsD1Ev'If you don't plan on using image functionality from `torchvision.io`, you can ignore this warning. Otherwise, there might be something wrong with your environment. Did you have `libjpeg` or `libpng` installed before building `torchvision` from source?
  warn(
Additional config directory '/home/arthur/transformers/benchmark/config' not found

Set the environment variable HYDRA_FULL_ERROR=1 for a complete stack trace.
Run benchmark on commit: 195e1adf82eccc994e54412de4638dc3b5f7ee6a
2024-05-09 11:04:45.890283: I tensorflow/tsl/cuda/cudart_stub.cc:28] Could not find cuda drivers on your machine, GPU will not be used.
2024-05-09 11:04:45.929361: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2024-05-09 11:04:46.515069: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
/home/arthur/miniconda3/envs/py39/lib/python3.9/site-packages/torchvision/io/image.py:13: UserWarning: Failed to load image Python extension: '/home/arthur/miniconda3/envs/py39/lib/python3.9/site-packages/torchvision/image.so: undefined symbol: _ZN3c1017RegisterOperatorsD1Ev'If you don't plan on using image functionality from `torchvision.io`, you can ignore this warning. Otherwise, there might be something wrong with your environment. Did you have `libjpeg` or `libpng` installed before building `torchvision` from source?
  warn(
[2024-05-09 11:04:48,404][HYDRA] Launching 1 jobs locally
[2024-05-09 11:04:48,404][HYDRA]        #0 : backend.model=google/gemma-2b
[MAIN-PROCESS][2024-05-09 11:04:48,778][hydra-cli][WARNING] - The `benchmark: inference` in your defaults list is deprecated. Please use `scenario: inference` instead.
[MAIN-PROCESS][2024-05-09 11:04:48,779][experiment][WARNING] - The `experiment` parent schema is deprecated and will be removed soon. Please use `benchmark` parent schema instead. You'll also need to change the `experiment_name` field to `name` and `benchmark` schema to `scenario` schema. See the repository README for more information.
[MAIN-PROCESS][2024-05-09 11:04:48,781][launcher][INFO] - Allocating process launcher
[MAIN-PROCESS][2024-05-09 11:04:48,781][process][INFO] -        + Setting multiprocessing start method to spawn.
[MAIN-PROCESS][2024-05-09 11:04:48,841][device-isolation][INFO] - Started device(s) isolation process [1629448], monitoring the isolated process [1629447], running on device(s) [0], with action [warn].
2024-05-09 11:04:52.189211: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
2024-05-09 11:04:52.200587: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
/home/arthur/miniconda3/envs/py39/lib/python3.9/site-packages/torchvision/io/image.py:13: UserWarning: Failed to load image Python extension: '/home/arthur/miniconda3/envs/py39/lib/python3.9/site-packages/torchvision/image.so: undefined symbol: _ZN3c1017RegisterOperatorsD1Ev'If you don't plan on using image functionality from `torchvision.io`, you can ignore this warning. Otherwise, there might be something wrong with your environment. Did you have `libjpeg` or `libpng` installed before building `torchvision` from source?
  warn(
/home/arthur/miniconda3/envs/py39/lib/python3.9/site-packages/torchvision/io/image.py:13: UserWarning: Failed to load image Python extension: '/home/arthur/miniconda3/envs/py39/lib/python3.9/site-packages/torchvision/image.so: undefined symbol: _ZN3c1017RegisterOperatorsD1Ev'If you don't plan on using image functionality from `torchvision.io`, you can ignore this warning. Otherwise, there might be something wrong with your environment. Did you have `libjpeg` or `libpng` installed before building `torchvision` from source?
  warn(
[ISOLATED-PROCESS][2024-05-09 11:04:53,827][process][INFO] - Running benchmark in isolated process [1629447].
[ISOLATED-PROCESS][2024-05-09 11:04:54,408][backend][INFO] - Allocating pytorch backend
[ISOLATED-PROCESS][2024-05-09 11:04:54,409][backend][INFO] -    + Setting random seed to 42
/home/arthur/miniconda3/envs/py39/lib/python3.9/site-packages/huggingface_hub/file_download.py:1132: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.
  warnings.warn(
[ISOLATED-PROCESS][2024-05-09 11:04:55,446][pytorch][INFO] -    + Using AutoModel AutoModelForCausalLM
[ISOLATED-PROCESS][2024-05-09 11:04:55,446][pytorch][INFO] -    + Creating backend temporary directory
[ISOLATED-PROCESS][2024-05-09 11:04:55,446][pytorch][INFO] -    + Loading model with random weights
[ISOLATED-PROCESS][2024-05-09 11:04:55,446][pytorch][INFO] -    + Creating no weights model directory
[ISOLATED-PROCESS][2024-05-09 11:04:55,446][pytorch][INFO] -    + Creating no weights model state dict
[ISOLATED-PROCESS][2024-05-09 11:04:55,447][pytorch][INFO] -    + Saving no weights model safetensors
[ISOLATED-PROCESS][2024-05-09 11:04:55,448][pytorch][INFO] -    + Saving no weights model pretrained config
[ISOLATED-PROCESS][2024-05-09 11:04:55,449][pytorch][INFO] -    + Loading Transformers model using device context manager for fast initialization
[ISOLATED-PROCESS][2024-05-09 11:04:55,848][pytorch][INFO] -    + Setting cache implementation to static
[ISOLATED-PROCESS][2024-05-09 11:04:55,849][pytorch][INFO] -    + Turning on model's eval mode
[ISOLATED-PROCESS][2024-05-09 11:04:55,849][pytorch][INFO] -    + Using torch.compile on forward
[ISOLATED-PROCESS][2024-05-09 11:04:55,852][scenario][INFO] - Allocating inference scenario
[ISOLATED-PROCESS][2024-05-09 11:04:55,853][inference][INFO] -  + Creating input generator
[ISOLATED-PROCESS][2024-05-09 11:04:55,853][input][INFO] -      + Using text-generation task generator
[ISOLATED-PROCESS][2024-05-09 11:04:55,853][inference][INFO] -  + Generating Text Generation inputs
[ISOLATED-PROCESS][2024-05-09 11:04:55,853][inference][INFO] -  + Preparing Text Generation inputs
[ISOLATED-PROCESS][2024-05-09 11:04:55,853][inference][INFO] -  + Updating Text Generation kwargs with default values
[ISOLATED-PROCESS][2024-05-09 11:04:55,853][inference][INFO] -  + Initializing Text Generation report
[ISOLATED-PROCESS][2024-05-09 11:04:55,854][inference][INFO] -  + Preparing backend for Inference
[ISOLATED-PROCESS][2024-05-09 11:04:55,854][inference][INFO] -  + Warming up backend for Text Generation
[ISOLATED-PROCESS][2024-05-09 11:05:42,775][inference][INFO] -  + Running Text Generation memory tracking
[ISOLATED-PROCESS][2024-05-09 11:05:42,776][memory][INFO] -     + Tracking RAM memory of process with PID [1629447]
[ISOLATED-PROCESS][2024-05-09 11:05:42,776][memory][INFO] -     + Tracking VRAM memory of CUDA devices with IDs [0]
[ISOLATED-PROCESS][2024-05-09 11:05:42,776][memory][INFO] -     + Tracking Allocated/Reserved memory of 1 Pytorch CUDA devices
2024-05-09 11:05:46.138887: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
/home/arthur/miniconda3/envs/py39/lib/python3.9/site-packages/torchvision/io/image.py:13: UserWarning: Failed to load image Python extension: '/home/arthur/miniconda3/envs/py39/lib/python3.9/site-packages/torchvision/image.so: undefined symbol: _ZN3c1017RegisterOperatorsD1Ev'If you don't plan on using image functionality from `torchvision.io`, you can ignore this warning. Otherwise, there might be something wrong with your environment. Did you have `libjpeg` or `libpng` installed before building `torchvision` from source?
  warn(
2024-05-09 11:05:51.057017: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
/home/arthur/miniconda3/envs/py39/lib/python3.9/site-packages/torchvision/io/image.py:13: UserWarning: Failed to load image Python extension: '/home/arthur/miniconda3/envs/py39/lib/python3.9/site-packages/torchvision/image.so: undefined symbol: _ZN3c1017RegisterOperatorsD1Ev'If you don't plan on using image functionality from `torchvision.io`, you can ignore this warning. Otherwise, there might be something wrong with your environment. Did you have `libjpeg` or `libpng` installed before building `torchvision` from source?
  warn(
2024-05-09 11:05:56.086866: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
/home/arthur/miniconda3/envs/py39/lib/python3.9/site-packages/torchvision/io/image.py:13: UserWarning: Failed to load image Python extension: '/home/arthur/miniconda3/envs/py39/lib/python3.9/site-packages/torchvision/image.so: undefined symbol: _ZN3c1017RegisterOperatorsD1Ev'If you don't plan on using image functionality from `torchvision.io`, you can ignore this warning. Otherwise, there might be something wrong with your environment. Did you have `libjpeg` or `libpng` installed before building `torchvision` from source?
  warn(
2024-05-09 11:06:01.013892: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
/home/arthur/miniconda3/envs/py39/lib/python3.9/site-packages/torchvision/io/image.py:13: UserWarning: Failed to load image Python extension: '/home/arthur/miniconda3/envs/py39/lib/python3.9/site-packages/torchvision/image.so: undefined symbol: _ZN3c1017RegisterOperatorsD1Ev'If you don't plan on using image functionality from `torchvision.io`, you can ignore this warning. Otherwise, there might be something wrong with your environment. Did you have `libjpeg` or `libpng` installed before building `torchvision` from source?
  warn(
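
As a rough illustration of the "tells you if you are slower" post-processing mentioned above: a hypothetical helper comparing one metric between two summaries (using the summary format documented later in this thread); it is not part of this PR.

```python
def relative_change(summary_main, summary_head, metric="decode.latency.mean"):
    # Compare one metric between the `main` run and the current-head run.
    base = summary_main["metrics"][metric]
    new = summary_head["metrics"][metric]
    # For latency metrics, a positive value means the current head is slower than main.
    return (new - base) / base
```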

benchmark/benchmark.py (resolved review comment)
Comment on lines +62 to +76
Each summary's format is as follows (for `expand_metrics=False`):
```
{
    "model": "google/gemma-2b",
    "commit": "3cd6ed22e4d49219f300f5055e71e3929aba20d7",
    "config": "benchmark.input_shapes.batch_size=1,benchmark.input_shapes.sequence_length=5",
    "metrics": {
        "decode.latency.mean": 1.624666809082031,
        "per_token.latency.mean": 0.012843788806628804,
        "per_token.throughput.value": 77.85864553330948
    }
}
```
"""
Collaborator:

nice !

benchmark/config/generation.yaml (resolved review comment)

benchmark/benchmark.py (resolved review comments)
Comment on lines 7 to 8
- override hydra/job_logging: colorlog # colorful logging
- override hydra/hydra_logging: colorlog # colorful logging
Collaborator:

does this need a special version of hydra?

ydshieh (Collaborator Author):

I am copying this from optimum-benchmark 😓. No idea, actually. We can discuss with @IlyasMoutawwakil

IlyasMoutawwakil (Member) commented May 13, 2024:

@ydshieh I reduced the dependency on hydra plugins; you don't need to set hydra_colorlog anymore. Check the examples in https://github.com/huggingface/optimum-benchmark/blob/main/examples/pytorch_bert.yaml, they're cleaner 🤗

Collaborator:

cool~

ydshieh (Collaborator Author):

@IlyasMoutawwakil Thanks. Currently I still need to pin optimum-benchmark to commit 995bf7b035fc6927c673d45f10a7b0b6784b2650, and some updates have to be done to match the recent changes you made.

Member:

I don't think there's a need, the change is very minimal (most of it in the config).

ydshieh (Collaborator Author):

(Well, I once installed the dev version and couldn't run anymore, but it probably works now; I haven't tried yet.)

Makefile Outdated
# Run benchmark

benchmark:
	python3 benchmark/benchmark.py --config-dir benchmark/config --config-name generation --commit=diff backend.model=google/gemma-2b --multirun
Collaborator:

Should this detect the devices to run on / run on all devices with a ||?
Also, we should run static, dynamic, etc. in this make bench.

ydshieh (Collaborator Author):

"should detect the devices to run / run on all devices with a || ?"

At this moment, let's focus on the single GPU device, which works well. We can iterate on the other cases if we find it necessary.

"Also we should run static, dynamic, etc. in this make bench"

OK, will do it.

@ArthurZucker ArthurZucker mentioned this pull request May 10, 2024
ydshieh (Collaborator Author) commented May 13, 2024

A naive run gave this:

arthur@brahms ~/transformers (benchmark_reuse_otpimum)> make benchmark                                                                                                                                                                               (py39) 
python3 benchmark/benchmark.py --config-dir benchmark/config --config-name generation 

@ArthurZucker

You are using --commit=diff, but the main branch doesn't contain the new files added in this PR. It will only work once this PR is merged. You can try with two commit SHAs for now, however.

ArthurZucker (Collaborator) commented:

No @ydshieh, the naive run error is: In 'hydra/config': Could not find 'hydra/job_logging/colorlog'

ArthurZucker (Collaborator) commented:

Basically we need an extras["benchmark"] to install all the packages for us!

@ydshieh ydshieh requested a review from ArthurZucker May 13, 2024 09:10
@ydshieh ydshieh changed the title [Benchmark] Reuse Otpimum [Benchmark] Reuse optimum-benchmark May 13, 2024
ArthurZucker (Collaborator) left a comment:

I don't have much more to say. We need an extra dep, and make benchmark could install what you need to run the benchmarks. You can split the comparison and the pushing to a dataset into another PR, but we need to get rid of all the logs and warnings.

ydshieh (Collaborator Author) commented May 13, 2024

No @ydshieh the naive run error is In 'hydra/config': Could not find 'hydra/job_logging/colorlog'

OK, I only saw

Additional config directory '/home/arthur/transformers/benchmark/config' not found

sorry.

Anyway, @IlyasMoutawwakil already reduced the dependency and I will remove it.

ydshieh (Collaborator Author) commented May 13, 2024

but we need to get rid of all the logs and warnings

What do you mean by this? You don't want to see any logs from optimum-benchmark?

IlyasMoutawwakil (Member) commented:

@ydshieh I think he meant the deprecation warnings about using the experiment schema instead of the benchmark schema.
Currently the CLI supports both and only warns about the deprecation and how to migrate from one to the other.

ydshieh (Collaborator Author) commented May 15, 2024

Hi @IlyasMoutawwakil

Currently, the way to install optimum-benchmark is:

pip install optimum-benchmark@git+https://github.com/huggingface/optimum-benchmark.git

Do you think there is a way to add optimum-benchmark to transformers' setup.py file?
Otherwise, we can probably add a requirements.txt in transformers' benchmark/.

Also, @ArthurZucker would like to disable the many logs from optimum-benchmark (more precisely, the [INFO] ones).
I could find relevant functions like

from optimum_benchmark.logging_utils import setup_logging
setup_logging(level="INFO", handlers=["console"])

However, since I am running the optimum-benchmark command line through a Python subprocess, do you think there is a way to disable the logs in this case? (Maybe it's already controllable via the config file?)

IlyasMoutawwakil (Member) commented:

@ydshieh I can do a PyPI release to make it easier to add to setup.py 🤗
For logging, I think a LOG_LEVEL env var makes sense in this case (it can be set in hydra.job.env_set).
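
A quick sketch of that idea, assuming the wrapper launches the CLI via subprocess and that optimum-benchmark honours a LOG_LEVEL environment variable (as proposed here; the PR later sets LOG_LEVEL: WARN in its config). The helper is illustrative only.

```python
import os
import subprocess

def run_quietly(cmd):
    # Pass LOG_LEVEL down to the optimum-benchmark subprocess to silence [INFO] logs.
    env = dict(os.environ, LOG_LEVEL="WARN")
    return subprocess.run(cmd, env=env, check=True)
```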

ydshieh (Collaborator Author) commented May 15, 2024

hydra.job.env_set

Thanks, I will give it a try (for the logging part).

For installation, I think I will have to update this PR to match the recent optimum-benchmark changes, but it should be fairly minimal, I guess. I will wait for a PyPI release from your side to finalize the PR.

ydshieh (Collaborator Author) commented May 16, 2024

Hi @IlyasMoutawwakil I am trying to run with the latest optimum-benchmark commit 9b28308362702169630ba269f677c11d3d7924b0. The goal is to match the new format. However, there is a failure, see below.

Let's work together to get a working commit (if the issue is on that side), and maybe use it for the PyPI release?

Traceback (most recent call last):
  File "/usr/local/bin/optimum-benchmark", line 5, in <module>
    from optimum_benchmark.cli import main
  File "/transformers/optimum-benchmark/optimum_benchmark/__init__.py", line 13, in <module>
    from .base import Benchmark
  File "/transformers/optimum-benchmark/optimum_benchmark/base.py", line 12, in <module>
    from .launchers.base import Launcher
  File "/transformers/optimum-benchmark/optimum_benchmark/launchers/base.py", line 10, in <module>
    from .device_isolation_utils import assert_device_isolation
  File "/transformers/optimum-benchmark/optimum_benchmark/launchers/device_isolation_utils.py", line 8, in <module>
    from ..logging_utils import setup_logging
  File "/transformers/optimum-benchmark/optimum_benchmark/logging_utils.py", line 50, in <module>
    def run_subprocess_and_log_stream_output(logger: logging.Logger, args: list[str]) -> Popen:
TypeError: 'type' object is not subscriptable
{}

ydshieh (Collaborator Author) commented May 16, 2024

FYI: currently the working version (with the current PR status) is based on 995bf7b035fc6927c673d45f10a7b0b6784b2650 (April 30).

benchmark/config/generation.yaml (outdated, resolved review comments)
Comment on lines 86 to 95
# Get the report
report_file = os.path.join(report_dir, "benchmark_report.json")
with open(report_file) as fp:
    report = json.load(fp)

# Get the full experiment config
config_file = os.path.join(report_dir, "experiment_config.json")
with open(config_file) as fp:
    config = json.load(fp)
model = config["backend"]["model"]
Member:

Better to use BenchmarkConfig.from_json and BenchmarkReport.from_json (or Benchmark.from_json, which loads benchmark.json into a dataclass that contains both the config and the report, as discussed before).
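
For reference, a minimal sketch of that suggestion. It only uses names mentioned in this thread (Benchmark imported from optimum_benchmark, Benchmark.from_json reading benchmark.json); the exact signature should be checked against the optimum-benchmark version in use.

```python
import os

from optimum_benchmark import Benchmark

def load_benchmark(report_dir):
    # benchmark.json is expected to hold both the config and the report in one dataclass.
    return Benchmark.from_json(os.path.join(report_dir, "benchmark.json"))
```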

ydshieh (Collaborator Author):

OK, completely forgot that conversation 😅

benchmark/config/generation.yaml (outdated, resolved review comment)
benchmark/benchmark.py (outdated, resolved review comments)
@ydshieh ydshieh force-pushed the benchmark_reuse_otpimum branch from ae6b806 to a86c3f6 on May 16, 2024 09:45
@ydshieh ydshieh force-pushed the benchmark_reuse_otpimum branch from 611cb1c to 7b0edfd on May 16, 2024 11:25
ydshieh (Collaborator Author) commented May 16, 2024

@ArthurZucker

I am still waiting for a PyPI release, but the other comments are all addressed:

  • no more logs
  • run with static/dynamic cache + compile/no compile, etc.

IlyasMoutawwakil (Member) commented:

@ydshieh https://pypi.org/project/optimum-benchmark/ 🤗

ydshieh (Collaborator Author) commented May 16, 2024

updated!

@ydshieh ydshieh requested a review from ArthurZucker May 16, 2024 12:05
ArthurZucker (Collaborator) left a comment:

LGTM. There are a lot of TODOs, but it's alright for a first PR.
Let's try to split the code, and talk about all the rest together!

Comment on lines +101 to +102
benchmark:
	python3 benchmark/benchmark.py --config-dir benchmark/config --config-name generation --commit=diff backend.model=google/gemma-2b backend.cache_implementation=null,static backend.torch_compile=false,true --multirun
Collaborator:

can this install the lib as well?

ydshieh (Collaborator Author):

So far this Makefile doesn't install anything in its commands; it's up to the users to install things beforehand.
Putting an installation step here would (try to) install on each run, which is not ideal.

I will add a comment.

@@ -42,6 +42,7 @@
"onnxruntime-tools": "onnxruntime-tools>=1.4.2",
"onnxruntime": "onnxruntime>=1.4.0",
"opencv-python": "opencv-python",
"optimum-benchmark": "optimum-benchmark>=0.2.0",
Collaborator:

cool

# set environment variable OVERRIDE_BENCHMARKS to 1
# to not skip benchmarks that have been run before
OVERRIDE_BENCHMARKS: 1
LOG_LEVEL: WARN
Collaborator:

nice

Collaborator:

In this file in particular there is a lot of boilerplate; it would be nice to split it into smaller functions instead of just having a comment, to make it more readable and easier to update/change!

ydshieh (Collaborator Author):

OK. That boilerplate is mostly about preparing the command-line arguments to send to optimum-benchmark. Let's see what could be improved over time, but it's not the priority at this moment.

I agree with you, however.

ydshieh (Collaborator Author) commented May 21, 2024

Going to merge. Once done, I will update the example usages (regarding the commit numbers).

Thanks for the reviews and help (@IlyasMoutawwakil)!

@ydshieh ydshieh merged commit 64e0573 into main May 21, 2024
22 checks passed
@ydshieh ydshieh deleted the benchmark_reuse_otpimum branch May 21, 2024 13:15