Add trt support for BF16 #195

Merged Jan 31, 2025 · 550 commits

Changes from all commits

Commits
eb4e7df
fix interface of `get_sample_input`
andompesta Oct 2, 2024
bf1cca6
save configuration parameters
andompesta Oct 2, 2024
358c8a5
ae wrapper implemented
andompesta Oct 2, 2024
381267d
fix import
andompesta Oct 2, 2024
a8af1d8
add AEWrapper step
andompesta Oct 2, 2024
a47608c
from set_model_to_dtype to prepare_model
andompesta Oct 3, 2024
ea420c5
fix eval mode during inference
andompesta Oct 3, 2024
af2f48b
fix clip onnx export. Now it traces only the needed outputs
andompesta Oct 3, 2024
e6b66bb
fix t5 wrapper
andompesta Oct 3, 2024
cb188d8
reorder input name flux
andompesta Oct 3, 2024
54002de
fix flux input format for text_ids and guidance
andompesta Oct 4, 2024
1cdc0a8
fix Flux imports and scale of inputs to prevent nan
andompesta Oct 4, 2024
c1c3a8d
add torch inference while tracing
andompesta Oct 4, 2024
bb0cc66
fix casting problem in onnx trace
andompesta Oct 4, 2024
21ec7d9
solve optimization problem by removing cleanup steps
andompesta Oct 4, 2024
d6f5e2f
rename to notes
andompesta Oct 4, 2024
577ba49
prevent nan due to large inputs
andompesta Oct 4, 2024
dfd06fc
provide base implementation of `get_model`
andompesta Oct 4, 2024
54b2ceb
format
andompesta Oct 4, 2024
d7ccef4
add trt export step
andompesta Oct 6, 2024
505411b
add engine class for trt build
andompesta Oct 6, 2024
0154232
add `get_input_profile` and `get_minmax_dims` abstract methods
andompesta Oct 6, 2024
cc2d921
add `build_strongly_typed` attribute
andompesta Oct 6, 2024
a2fb731
implement `get_minmax_dims` and `get_input_profile`
andompesta Oct 6, 2024
0096f7a
remove `static_shape` from `get_sample_input`
andompesta Oct 6, 2024
dfb6ded
remove static shape and batch flags
andompesta Oct 7, 2024
50100b5
add typing
andompesta Oct 7, 2024
ea240be
remove static shape and batch flags
andompesta Oct 7, 2024
0c3720c
offload to cpu
andompesta Oct 7, 2024
30f0140
enable device offloading while tracing
andompesta Oct 7, 2024
f2b357a
check cuda is available while building engines
andompesta Oct 7, 2024
a30ec20
clip trt engine build
andompesta Oct 7, 2024
dbeeed9
add pinned transformer dependency
andompesta Oct 9, 2024
0682915
fix nan with onnx and trt when executed on CUDA
andompesta Oct 9, 2024
bef25e0
AE needs to be traced in TF32, not FP16
andompesta Oct 9, 2024
c028d8d
add `get_shape_dict` abstract method and device as a property
andompesta Oct 9, 2024
8208e4c
AE should be traced in TF32
andompesta Oct 9, 2024
816ff12
AE explicitly on TF32 and reactivate full pipeline
andompesta Oct 9, 2024
3a341f8
add input profile to flux to enable trt engine build
andompesta Oct 9, 2024
7aa6956
format and add input_profile to t5 for TRT build
andompesta Oct 9, 2024
e68a993
add `TransformersModelWrapper`
andompesta Oct 9, 2024
ea581b7
add TransformersModelWrapper support
andompesta Oct 9, 2024
7e883d5
add `get_shape_dict` interface
andompesta Oct 9, 2024
5080d86
add TransformersModelWrapper support
andompesta Oct 9, 2024
e2b65c4
add shape_dict interface
andompesta Oct 9, 2024
87413e2
t5 in TF32 for numerical reasons
andompesta Oct 11, 2024
8629e50
remove unused options
andompesta Oct 11, 2024
5e711c7
remove unused code
andompesta Oct 11, 2024
02235dc
add `get_shape_dict`
andompesta Oct 11, 2024
6c3c4db
remove custom optimization
andompesta Oct 11, 2024
4b8a973
add garbage collector
andompesta Oct 14, 2024
8e4b103
return error
andompesta Oct 14, 2024
8f45f81
create wrapper specific to Onnx export operation
andompesta Oct 14, 2024
3af1a33
use OnnxWrapper
andompesta Oct 14, 2024
fe024b8
create base wrapper for trt engines
andompesta Oct 14, 2024
68060bd
moved to engine package
andompesta Oct 14, 2024
0f8d8b3
moved to engine package
andompesta Oct 14, 2024
49dc6d1
forbid relative import of trt-builder
andompesta Oct 14, 2024
098391b
remove wrapper and create BaseExporter or BaseEngine
andompesta Oct 14, 2024
bf9c4cb
models not stored in builder class
andompesta Oct 14, 2024
0ee9104
_prepare_model_configs as pure function
andompesta Oct 14, 2024
c7136f8
_get_onnx_exporters as a private method to get onnx exporters
andompesta Oct 14, 2024
ee72695
remove unused dependencies
andompesta Oct 14, 2024
ecf6c4f
from onnxwrapper to onnxengine
andompesta Oct 14, 2024
2a14000
trt engine class
andompesta Oct 14, 2024
c791c53
add `calculate_max_device_memory` to TRTBuilder
andompesta Oct 14, 2024
ce343dc
`get_shape_dict` moved to trt-engine interface
andompesta Oct 14, 2024
66ca1ce
add common inference code
andompesta Oct 14, 2024
7400072
autoencoder inference wrapper
andompesta Oct 14, 2024
aa0d474
add requirements.txt
Oct 16, 2024
d676a18
support guidance for dev model
andompesta Oct 16, 2024
550f660
add support for trt based on env variables
andompesta Oct 16, 2024
fa5993b
format flux
andompesta Oct 16, 2024
bdbbb19
remove stream from constructor
andompesta Oct 16, 2024
f1d86f6
fix iteration over onnx-exporters
andompesta Oct 16, 2024
f065b09
flux is not strongly typed
andompesta Oct 16, 2024
c57410a
move back for numerical stability
andompesta Oct 16, 2024
69f4dca
add logging
andompesta Oct 16, 2024
cc12a14
fix dtype casting for bfloat16
andompesta Oct 17, 2024
961259e
fix default value
andompesta Oct 17, 2024
6e1ca02
add version before merge
andompesta Oct 18, 2024
7217a7b
hacky: get it building the engines
ducktrA Oct 15, 2024
c5481a1
requirements.txt
ducktrA Oct 17, 2024
54674c3
adding a separate _engine.py file for the flux, t5 and clip engines
ducktrA Oct 18, 2024
37003c7
boilerroom and plating. getting parameters handle into setting up the…
ducktrA Oct 18, 2024
fd33eb5
remove _version.py from git
andompesta Oct 18, 2024
99e72e9
create base mixin class to share parameters
andompesta Oct 18, 2024
6678a3b
clipmixin parameters
andompesta Oct 18, 2024
395541d
remove parameters as they are part of mixin class
andompesta Oct 18, 2024
315dd9d
clip engine and exporter use common mixin for managing parameters
andompesta Oct 18, 2024
7cdbb03
use mixin class to build engine from exporter
andompesta Oct 18, 2024
55497eb
ae-mixin for shared parameters
andompesta Oct 18, 2024
5917f38
flux exporter and engine unified by mixin class
andompesta Oct 21, 2024
7c156cd
formatting
andompesta Oct 21, 2024
92f13f8
add common `get_latent_dims` method
andompesta Oct 21, 2024
f5acd54
add `get_latent_dims` common method
andompesta Oct 21, 2024
8b182cc
T5 based on mixin class
andompesta Oct 21, 2024
11570dc
build strongly typed flux
andompesta Oct 21, 2024
a9acfa0
enable load with shared device memory
andompesta Oct 21, 2024
c6e94a6
remove boilerplate code to create engines
andompesta Oct 21, 2024
7b07602
add tokenizer to trt engine
andompesta Oct 22, 2024
2dc2460
use static shape to reduce memory consumption
andompesta Oct 22, 2024
40de55c
implement tokenizer in t5 engine
andompesta Oct 22, 2024
c8273c7
fix max_batch size to 8
andompesta Oct 22, 2024
b96fd96
add licence
andompesta Oct 22, 2024
6743bb7
add licence
andompesta Oct 22, 2024
852b444
enable trt runtime tracking
andompesta Oct 22, 2024
95f7822
add static-batch and static-shape options
andompesta Oct 22, 2024
8ac3f84
add cuda stream to load method
andompesta Oct 22, 2024
f93fc87
add inference code
andompesta Oct 22, 2024
528621a
add inference code
andompesta Oct 22, 2024
23e1236
enable static shape
andompesta Oct 22, 2024
dc326df
add `static_shape` option to reduce memory and `_build_engine` as sta…
andompesta Oct 22, 2024
7e3fe14
add `should_be_dtype` field to handle output type conversion
andompesta Oct 22, 2024
41f18e7
from trtbuilder to trt_manager
andompesta Oct 22, 2024
12dee48
from TRTBuilder to TRTManager
andompesta Oct 23, 2024
45997a9
AE engine interface
andompesta Oct 23, 2024
bb9f468
`trt_to_torch_dtype_dict` as property
andompesta Oct 23, 2024
2bde369
clip engine inference
andompesta Oct 23, 2024
359572e
implement flux trt engine inference process
andompesta Oct 23, 2024
e3f0fd9
add scale_factor and shift_factor
andompesta Oct 23, 2024
d91bbde
removed `should_be_dtype`
andompesta Oct 23, 2024
df245db
removed `should_be_dtype`
andompesta Oct 23, 2024
33bc095
remove `should_be_dtype` from t5
andompesta Oct 23, 2024
c330491
add scale and shift factor
andompesta Oct 23, 2024
90b4f11
`max_batch` to 8
andompesta Oct 23, 2024
17c1f7d
implement `TRTManager`
andompesta Oct 23, 2024
811f2ff
from ae to vae to match DD
andompesta Oct 25, 2024
f4ae3ca
remove autocast
andompesta Oct 25, 2024
0fe7c84
`pooled_embeddings` to match DD naming for clip
andompesta Oct 25, 2024
f71091a
rename `flux` to `transformer` engine
andompesta Oct 25, 2024
4055a3e
from flux to transformer mixin
andompesta Oct 25, 2024
2b2bb5b
from flux to transformer exporter
andompesta Oct 25, 2024
b088430
fix trtmanager naming
andompesta Oct 25, 2024
82d658d
fix input names and dimensions. Note that `img_ids` and `txt_ids` ar…
andompesta Oct 25, 2024
3708773
fix shape of inputs according to `text_maxlen` and batch_size
andompesta Oct 25, 2024
7737426
reduce max_batch
andompesta Oct 27, 2024
917d8ff
fix stage naming
andompesta Oct 27, 2024
6473ca1
add support for DD model
andompesta Oct 27, 2024
6d39ad5
add support for DD models
andompesta Oct 27, 2024
753129b
fix dtype configuration
andompesta Oct 28, 2024
149c27c
fix engine dtype
andompesta Oct 28, 2024
55568bf
transformers inference interface to match DD
andompesta Oct 28, 2024
4872169
vae inference script dtype mapping
andompesta Oct 28, 2024
41ee44c
remove dtype checks as multiple can be active
andompesta Oct 28, 2024
a31161d
by default tf32 always active
andompesta Oct 28, 2024
3b91c51
fix trt engine names
andompesta Nov 11, 2024
4ebca7d
add wrapper for fluxmodel to match DD onnx configuration
andompesta Nov 11, 2024
3e9f64f
add autocast back in to match DD setup
andompesta Nov 11, 2024
bb82e4b
fix dependencies for trt support
andompesta Nov 14, 2024
830358e
support trt
andompesta Nov 14, 2024
cdce3a3
add explicit kwargs
andompesta Nov 14, 2024
b789e05
vscode setup
andompesta Nov 14, 2024
8b07e6e
add setup instructions for trt
andompesta Nov 14, 2024
5ffd6d6
`trt` dependencies not part of `all`
andompesta Nov 14, 2024
766d878
from onnx_exporter to exporter
andompesta Nov 14, 2024
6d83690
hide onnx parameters
andompesta Nov 14, 2024
2458486
from onnx-exporter to exporter
andompesta Nov 14, 2024
80a52d7
exporter responsible for building trt engine and onnx export
andompesta Nov 14, 2024
adf2d46
hide onnx parameter
andompesta Nov 14, 2024
e82311f
remove build function from engine class
andompesta Nov 14, 2024
17f6562
remove unused import
andompesta Nov 14, 2024
2512bb2
remove space
andompesta Nov 14, 2024
86614a3
manage t5 and vae separately
andompesta Nov 14, 2024
f14de69
disable autocast
andompesta Nov 14, 2024
3410d34
strongly typed t5
andompesta Nov 14, 2024
2422538
fix input type and max image size
andompesta Nov 14, 2024
9bef65b
max image size
andompesta Nov 14, 2024
a3bd8fc
T5 not strongly typed
andompesta Nov 14, 2024
e615fa0
testing
andompesta Nov 14, 2024
611efed
fix torch synchronize problem
andompesta Nov 14, 2024
13b1016
don't build already present engines
andompesta Nov 14, 2024
01b508c
remove torch save
andompesta Nov 14, 2024
f57b5a5
removed onnx dependencies
andompesta Nov 14, 2024
9cffa24
add trt dependencies
andompesta Nov 14, 2024
63e29cc
remove trt dependencies from toml
andompesta Nov 14, 2024
c978cc3
rename requirements and fix readme
andompesta Nov 14, 2024
3087c60
remove unused files
andompesta Nov 14, 2024
5c2cba1
fix import format
andompesta Nov 14, 2024
08fbb60
remove comments
andompesta Nov 14, 2024
1b4a41a
add gitignore
andompesta Nov 15, 2024
a404144
reset dependencies
andompesta Nov 15, 2024
a8b8478
add hidden setup files
andompesta Nov 15, 2024
8fa1d22
solve ruff check
andompesta Nov 15, 2024
3f20508
fix imports with ruff
andompesta Nov 15, 2024
7662313
run ruff formatter
andompesta Nov 15, 2024
4691502
update gitignore
andompesta Nov 15, 2024
deb5633
simplify dependencies
andompesta Nov 18, 2024
1de2799
remove gitignore
andompesta Nov 18, 2024
64cbb8f
add cli formatting
andompesta Nov 18, 2024
fd1455e
fix import orders
andompesta Nov 18, 2024
095ee89
Merge pull request #1 from andompesta/add-trt-support-push
andompesta Nov 18, 2024
3d3741e
simplify dependencies
andompesta Nov 18, 2024
f31ffd4
solve vae quality issue
andompesta Nov 26, 2024
728c018
Merge branch 'main' of https://github.com/black-forest-labs/flux
andompesta Nov 26, 2024
1cd9476
Merge branch 'main' into add-trt-support
andompesta Nov 26, 2024
bee6c45
Merge branch 'main' into add-trt-support-cli-conflict
andompesta Nov 26, 2024
f80058f
fix ruff format
andompesta Nov 26, 2024
079778f
fix merge changes
andompesta Nov 26, 2024
a5986b5
format and sort src/flux/cli
andompesta Nov 26, 2024
c7fdb64
fix merge conflicts
andompesta Nov 26, 2024
74c4c7a
Merge pull request #2 from andompesta/add-trt-support-cli-conflict
andompesta Nov 26, 2024
631d039
Merge branch 'main' of https://github.com/black-forest-labs/flux into…
Jan 14, 2025
e29b5eb
add trt import
andompesta Jan 14, 2025
c962da1
add static shape support (not completed)
andompesta Jan 14, 2025
5801579
remove fp8 support
andompesta Jan 14, 2025
ff29e9a
add static shape
andompesta Jan 14, 2025
973353c
add static shape to t5
andompesta Jan 14, 2025
200bfe5
add static shape to transformer
andompesta Jan 14, 2025
eb0217b
remove model opt code
andompesta Jan 14, 2025
05e2378
enable offloading with trt engines
andompesta Jan 14, 2025
c01c47a
add `stream` as part of `init_runtime`
andompesta Jan 14, 2025
45cfc62
enable offloading
andompesta Jan 14, 2025
31d195d
`allocate_buffers` moved to call
andompesta Jan 14, 2025
9fd1008
formatting
andompesta Jan 14, 2025
f9d3fad
add capability to compute `img_dim`
andompesta Jan 14, 2025
57174d4
enable dynamic or static-shape
andompesta Jan 14, 2025
449980f
split base-engine and engine class
andompesta Jan 15, 2025
6818690
clip as engine
andompesta Jan 15, 2025
e203148
t5 as engine
andompesta Jan 15, 2025
632531f
transformer as engine
andompesta Jan 15, 2025
51f7790
VAEDecoder as engine and VAEEngine as BaseEngine
andompesta Jan 15, 2025
3aeefa4
from vae to vae_decoder, vae_encoder and vae
andompesta Jan 15, 2025
880cb67
use `set_stream` and fix activate call
andompesta Jan 15, 2025
eeb039d
fix import and remove stages in TRTManager
andompesta Jan 15, 2025
6316f51
from BaseEngine to BaseEngine and Engine
andompesta Jan 15, 2025
f071acc
fix imports
andompesta Jan 15, 2025
db3cb60
add trt support to cli_controlnet
andompesta Jan 15, 2025
3e180f8
add vae_encoder to support controlnet
andompesta Jan 15, 2025
d56ae30
refactor vae engine to use load() and activate() functions
andompesta Jan 15, 2025
c16984d
implement vae_encoder_exporter. Not tested
andompesta Jan 15, 2025
4c24f59
fix imports
andompesta Jan 15, 2025
0f6af2d
add static_batch and static_shape to cli.py as additional options?
andompesta Jan 24, 2025
acd1d13
update dependencies
andompesta Jan 24, 2025
c3e3f23
revert formatting
andompesta Jan 24, 2025
e2b41eb
Merge branch 'add-trt-support' of github.com:andompesta/flux into add…
andompesta Jan 24, 2025
55219cb
Merge branch 'add-trt-support' into add-trt-support-controlnet
andompesta Jan 24, 2025
6f58632
Merge pull request #3 from andompesta/add-trt-support-controlnet
andompesta Jan 24, 2025
81b807e
from Self to Any to be compatible with python 3.10
andompesta Jan 25, 2025
4ce7974
from `vae_decoder` to `vae` for compatibility with oss engines
andompesta Jan 25, 2025
5d0b780
missing torch import
andompesta Jan 25, 2025
6394f76
Merge pull request #4 from andompesta/add-trt-support-controlnet
andompesta Jan 25, 2025
1d11bb4
add `scale_factor` and `shift_factor` to VAE-encoder
andompesta Jan 28, 2025
48533df
add check if vae is traced
andompesta Jan 28, 2025
06efe08
offload while tracing
andompesta Jan 28, 2025
4dbc7d2
default `text_maxlen` set to dev size instead of schnell
andompesta Jan 28, 2025
0071782
remove line
andompesta Jan 28, 2025
a76d97b
add warning when text_maxlen is not read from t5
andompesta Jan 28, 2025
40fe2df
fix imports
andompesta Jan 28, 2025
fa62fa8
Merge pull request #5 from andompesta/add-trt-support-controlnet
andompesta Jan 28, 2025
66 changes: 65 additions & 1 deletion README.md
@@ -17,9 +17,22 @@ source .venv/bin/activate
pip install -e ".[all]"
```

## Local installation with TRT support

```bash
docker pull nvcr.io/nvidia/pytorch:24.10-py3
cd $HOME && git clone https://github.com/black-forest-labs/flux
cd $HOME/flux
docker run --rm -it --gpus all -v $PWD:/workspace/flux nvcr.io/nvidia/pytorch:24.10-py3 /bin/bash
# inside container
cd /workspace/flux
pip install -e ".[all]"
pip install -r trt_requirements.txt
```
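
Once inside the container, the TRT path can be exercised through the sampling CLI. A sketch of a possible invocation, based on the `--trt` flag and the `TRT_ENGINE_DIR`/`ONNX_DIR` environment variables introduced in the `src/flux/cli.py` changes later in this diff (flag names are from the PR; values shown are assumptions):

```bash
# hypothetical run, assuming the CLI changes in this PR
export TRT_ENGINE_DIR=./engines  # where built TRT engines are cached
export ONNX_DIR=./onnx           # where intermediate ONNX exports are written
python src/flux/cli.py --name flux-dev --width 1024 --height 1024 \
  --prompt "a photo of a forest with mist" --trt
```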

### Models

We are offering an extensive suite of models. For more information about the individual models, please refer to the link under **Usage**.

| Name | Usage | HuggingFace repo | License |
| --------------------------- | ---------------------------------------------------------- | -------------------------------------------------------------- | --------------------------------------------------------------------- |
@@ -42,6 +55,57 @@ We are offering an extensive suite of models. For more information about the ind

The weights of the autoencoder are also released under [apache-2.0](https://huggingface.co/datasets/choosealicense/licenses/blob/main/markdown/apache-2.0.md) and can be found in the HuggingFace repos above.

We also offer a Gradio-based demo for an interactive experience. To run the Gradio demo:

```bash
python demo_gr.py --name flux-schnell --device cuda
```

Options:

- `--name`: Choose the model to use (options: "flux-schnell", "flux-dev")
- `--device`: Specify the device to use (default: "cuda" if available, otherwise "cpu")
- `--offload`: Offload model to CPU when not in use
- `--share`: Create a public link to your demo

To run the demo with the dev model and create a public link:

```bash
python demo_gr.py --name flux-dev --share
```

## Diffusers integration

`FLUX.1 [schnell]` and `FLUX.1 [dev]` are integrated with the [🧨 diffusers](https://github.com/huggingface/diffusers) library. To use it with diffusers, install it:

```shell
pip install git+https://github.com/huggingface/diffusers.git
```

Then you can use `FluxPipeline` to run the model:

```python
import torch
from diffusers import FluxPipeline

model_id = "black-forest-labs/FLUX.1-schnell"  # you can also use `black-forest-labs/FLUX.1-dev`

pipe = FluxPipeline.from_pretrained(model_id, torch_dtype=torch.bfloat16)
pipe.enable_model_cpu_offload()  # save some VRAM by offloading the model to CPU. Remove this if you have enough GPU power

prompt = "A cat holding a sign that says hello world"
seed = 42
image = pipe(
prompt,
output_type="pil",
    num_inference_steps=4,  # use a larger number if you are using [dev]
generator=torch.Generator("cpu").manual_seed(seed)
).images[0]
image.save("flux-schnell.png")
```

To learn more, check out the [diffusers](https://huggingface.co/docs/diffusers/main/en/api/pipelines/flux) documentation.

## API usage

Our API offers access to our models. It is documented here:
3 changes: 0 additions & 3 deletions demo_gr.py
@@ -15,7 +15,6 @@

NSFW_THRESHOLD = 0.85


def get_models(name: str, device: torch.device, offload: bool, is_schnell: bool):
t5 = load_t5(device, max_length=256 if is_schnell else 512)
clip = load_clip(device)
@@ -24,7 +23,6 @@ def get_models(name: str, device: torch.device, offload: bool, is_schnell: bool)
nsfw_classifier = pipeline("image-classification", model="Falconsai/nsfw_image_detection", device=device)
return model, ae, t5, clip, nsfw_classifier


class FluxGenerator:
def __init__(self, model_name: str, device: str, offload: bool):
self.device = torch.device(device)
@@ -153,7 +151,6 @@ def generate_image(
exif_data[ExifTags.Base.Model] = self.model_name
if add_sampling_metadata:
exif_data[ExifTags.Base.ImageDescription] = prompt

img.save(filename, format="jpeg", exif=exif_data, quality=95, subsampling=0)

return img, str(opts.seed), filename, None
67 changes: 65 additions & 2 deletions src/flux/cli.py
@@ -5,10 +5,12 @@
from glob import iglob

import torch
from cuda import cudart
from fire import Fire
from transformers import pipeline

from flux.sampling import denoise, get_noise, get_schedule, prepare, unpack
from flux.trt.trt_manager import TRTManager
from flux.util import configs, load_ae, load_clip, load_flow_model, load_t5, save_image

NSFW_THRESHOLD = 0.85
@@ -25,7 +27,9 @@ class SamplingOptions:


def parse_prompt(options: SamplingOptions) -> SamplingOptions | None:
user_question = "Next prompt (write /h for help, /q to quit and leave empty to repeat):\n"
user_question = (
"Next prompt (write /h for help, /q to quit and leave empty to repeat):\n"
)
usage = (
"Usage: Either write your prompt directly, leave this field empty "
"to repeat the prompt or write a command starting with a slash:\n"
@@ -108,6 +112,8 @@ def main(
offload: bool = False,
output_dir: str = "output",
add_sampling_metadata: bool = True,
trt: bool = False,
**kwargs: dict | None,
):
"""
Sample the flux model. Either interactively (set `--loop`) or run for a
@@ -126,6 +132,8 @@
loop: start an interactive session and sample multiple times
guidance: guidance value used for guidance distillation
add_sampling_metadata: Add the prompt to the image Exif metadata
trt: use TensorRT backend for optimized inference
kwargs: additional arguments for TensorRT support
"""
nsfw_classifier = pipeline("image-classification", model="Falconsai/nsfw_image_detection", device=device)

@@ -158,6 +166,57 @@
model = load_flow_model(name, device="cpu" if offload else torch_device)
ae = load_ae(name, device="cpu" if offload else torch_device)

if trt:
# offload to CPU to save memory
ae = ae.cpu()
model = model.cpu()
clip = clip.cpu()
t5 = t5.cpu()

torch.cuda.empty_cache()

trt_ctx_manager = TRTManager(
bf16=True,
device=torch_device,
static_batch=kwargs.get("static_batch", True),
static_shape=kwargs.get("static_shape", True),
)
ae.decoder.params = ae.params
engines = trt_ctx_manager.load_engines(
models={
"clip": clip,
"transformer": model,
"t5": t5,
"vae": ae.decoder,
},
engine_dir=os.environ.get("TRT_ENGINE_DIR", "./engines"),
onnx_dir=os.environ.get("ONNX_DIR", "./onnx"),
opt_image_height=height,
opt_image_width=width,
)

torch.cuda.synchronize()

trt_ctx_manager.init_runtime()
# TODO: refactor. stream should be part of engine constructor maybe !!
for _, engine in engines.items():
engine.set_stream(stream=trt_ctx_manager.stream)

if not offload:
for _, engine in engines.items():
engine.load()

calculate_max_device_memory = trt_ctx_manager.calculate_max_device_memory(engines)
_, shared_device_memory = cudart.cudaMalloc(calculate_max_device_memory)

for _, engine in engines.items():
engine.activate(device=torch_device, device_memory=shared_device_memory)

ae = engines["vae"]
model = engines["transformer"]
clip = engines["clip"]
t5 = engines["t5"]

rng = torch.Generator(device="cpu")
opts = SamplingOptions(
prompt=prompt,
@@ -192,7 +251,9 @@ def main(
torch.cuda.empty_cache()
t5, clip = t5.to(torch_device), clip.to(torch_device)
inp = prepare(t5, clip, x, prompt=opts.prompt)
timesteps = get_schedule(opts.num_steps, inp["img"].shape[1], shift=(name != "flux-schnell"))
timesteps = get_schedule(
opts.num_steps, inp["img"].shape[1], shift=(name != "flux-schnell")
)

# offload TEs to CPU, load model to gpu
if offload:
@@ -229,6 +290,8 @@ def main(
else:
opts = None

if trt:
trt_ctx_manager.stop_runtime()

def app():
Fire(main)
50 changes: 50 additions & 0 deletions src/flux/cli_control.py
@@ -5,11 +5,13 @@
from glob import iglob

import torch
from cuda import cudart
from fire import Fire
from transformers import pipeline

from flux.modules.image_embedders import CannyImageEncoder, DepthImageEncoder
from flux.sampling import denoise, get_noise, get_schedule, prepare_control, unpack
from flux.trt.trt_manager import TRTManager
from flux.util import configs, load_ae, load_clip, load_flow_model, load_t5, save_image


@@ -174,6 +176,8 @@ def main(
add_sampling_metadata: bool = True,
img_cond_path: str = "assets/robot.webp",
lora_scale: float | None = 0.85,
trt: bool = False,
**kwargs: dict | None,
):
"""
Sample the flux model. Either interactively (set `--loop`) or run for a
@@ -192,6 +196,7 @@ def main(
guidance: guidance value used for guidance distillation
add_sampling_metadata: Add the prompt to the image Exif metadata
img_cond_path: path to conditioning image (jpeg/png/webp)
trt: use TensorRT backend for optimized inference
"""
nsfw_classifier = pipeline("image-classification", model="Falconsai/nsfw_image_detection", device=device)

@@ -234,6 +239,7 @@ def main(

# set lora scale
if "lora" in name and lora_scale is not None:
assert not trt, "TRT does not support LORA yet"
for _, module in model.named_modules():
if hasattr(module, "set_scale"):
module.set_scale(lora_scale)
@@ -245,6 +251,50 @@ def main(
else:
raise NotImplementedError()

if trt:
trt_ctx_manager = TRTManager(
bf16=True,
device=torch_device,
static_batch=kwargs.get("static_batch", True),
static_shape=kwargs.get("static_shape", True),
)
ae.decoder.params = ae.params
ae.encoder.params = ae.params
engines = trt_ctx_manager.load_engines(
models={
"clip": clip.cpu(),
"transformer": model.cpu(),
"t5": t5.cpu(),
"vae": ae.decoder.cpu(),
"vae_encoder": ae.encoder.cpu(),
},
engine_dir=os.environ.get("TRT_ENGINE_DIR", "./engines"),
onnx_dir=os.environ.get("ONNX_DIR", "./onnx"),
opt_image_height=height,
opt_image_width=width,
)
torch.cuda.synchronize()

trt_ctx_manager.init_runtime()
# TODO: refactor. stream should be part of engine constructor maybe !!
for _, engine in engines.items():
engine.set_stream(stream=trt_ctx_manager.stream)

if not offload:
for _, engine in engines.items():
engine.load()

calculate_max_device_memory = trt_ctx_manager.calculate_max_device_memory(engines)
_, shared_device_memory = cudart.cudaMalloc(calculate_max_device_memory)

for _, engine in engines.items():
engine.activate(device=torch_device, device_memory=shared_device_memory)

ae = engines["vae"]
model = engines["transformer"]
clip = engines["clip"]
t5 = engines["t5"]

rng = torch.Generator(device="cpu")
opts = SamplingOptions(
prompt=prompt,
2 changes: 1 addition & 1 deletion src/flux/math.py
@@ -14,7 +14,7 @@ def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor) -> Tensor:

def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
assert dim % 2 == 0
scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
scale = torch.arange(0, dim, 2, dtype=pos.dtype, device=pos.device) / dim
omega = 1.0 / (theta**scale)
out = torch.einsum("...n,d->...nd", pos, omega)
out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1)
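
For context on the `rope` change above: TensorRT has no FP64 support, so computing the frequency `scale` in `torch.float64` breaks ONNX/TRT export, and deriving the dtype from `pos` keeps the whole rotary-embedding computation in a traceable precision. A minimal sketch of the new behaviour (the `dim` and `theta` values here are illustrative, not from the PR):

```python
import torch

# with the updated rope, frequencies follow pos.dtype instead of float64
pos = torch.arange(128, dtype=torch.bfloat16)
dim, theta = 16, 10_000
scale = torch.arange(0, dim, 2, dtype=pos.dtype, device=pos.device) / dim
omega = 1.0 / (theta**scale)
out = torch.einsum("...n,d->...nd", pos, omega)
print(out.dtype)  # torch.bfloat16: no float64 constant ends up in the traced graph
```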
6 changes: 6 additions & 0 deletions src/flux/modules/autoencoder.py
@@ -235,6 +235,9 @@
self.conv_out = nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)

def forward(self, z: Tensor) -> Tensor:
# get dtype for proper tracing
upscale_dtype = next(self.up.parameters()).dtype

# z to block_in
h = self.conv_in(z)

@@ -243,6 +246,8 @@ def forward(self, z: Tensor) -> Tensor:
h = self.mid.attn_1(h)
h = self.mid.block_2(h)

# cast to proper dtype
h = h.to(upscale_dtype)
# upsampling
for i_level in reversed(range(self.num_resolutions)):
for i_block in range(self.num_res_blocks + 1):
@@ -277,6 +282,7 @@ def forward(self, z: Tensor) -> Tensor:
class AutoEncoder(nn.Module):
def __init__(self, params: AutoEncoderParams):
super().__init__()
self.params = params
self.encoder = Encoder(
resolution=params.resolution,
in_channels=params.in_channels,
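
The `upscale_dtype` cast above is what lets the decoder trace cleanly when its upsampling blocks sit in a different precision than the mid blocks (the commit history keeps the AE in TF32 while other modules move to BF16). A toy repro of the pattern, with hypothetical module names:

```python
import torch
import torch.nn as nn

# stand-in for Decoder.forward's dtype handling; illustrative only
class TinyDecoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.mid = nn.Conv2d(4, 4, 3, padding=1)                    # stays in fp32
        self.up = nn.Conv2d(4, 4, 3, padding=1).to(torch.bfloat16)  # converted block

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        upscale_dtype = next(self.up.parameters()).dtype  # bf16 here
        h = self.mid(z)                                   # fp32 path
        h = h.to(upscale_dtype)                           # cast before upsampling, as in the PR
        return self.up(h)

out = TinyDecoder()(torch.randn(1, 4, 8, 8))
print(out.dtype)  # torch.bfloat16; without the cast, the up-block raises a dtype mismatch
```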
Empty file added src/flux/trt/__init__.py
32 changes: 32 additions & 0 deletions src/flux/trt/engine/__init__.py
@@ -0,0 +1,32 @@
#
# SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from flux.trt.engine.base_engine import BaseEngine, Engine
from flux.trt.engine.clip_engine import CLIPEngine
from flux.trt.engine.t5_engine import T5Engine
from flux.trt.engine.transformer_engine import TransformerEngine
from flux.trt.engine.vae_engine import VAEEngine, VAEDecoder, VAEEncoder

__all__ = [
"BaseEngine",
"Engine",
"CLIPEngine",
"TransformerEngine",
"T5Engine",
"VAEEngine",
"VAEDecoder",
"VAEEncoder",
]
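
Pulling the pieces together, here is a condensed sketch of how these engine classes are driven, mirroring the `cli.py` changes earlier in this diff (names and call order are taken from the PR; treat this as illustrative, not a stable API):

```python
import os

from cuda import cudart

from flux.trt.trt_manager import TRTManager


def build_and_activate(models: dict, device: str = "cuda", height: int = 1024, width: int = 1024):
    # export ONNX and build (or load cached) TRT engines for each model
    manager = TRTManager(bf16=True, device=device, static_batch=True, static_shape=True)
    engines = manager.load_engines(
        models=models,  # e.g. {"clip": ..., "t5": ..., "transformer": ..., "vae": ...}
        engine_dir=os.environ.get("TRT_ENGINE_DIR", "./engines"),
        onnx_dir=os.environ.get("ONNX_DIR", "./onnx"),
        opt_image_height=height,
        opt_image_width=width,
    )
    manager.init_runtime()
    for engine in engines.values():
        engine.set_stream(stream=manager.stream)
        engine.load()
    # all engines share one device allocation sized to the largest activation footprint
    _, shared_device_memory = cudart.cudaMalloc(manager.calculate_max_device_memory(engines))
    for engine in engines.values():
        engine.activate(device=device, device_memory=shared_device_memory)
    return engines
```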