diff --git a/.github/workflows/typos.yaml b/.github/workflows/typos.yaml index eb859574a..e8b06483f 100644 --- a/.github/workflows/typos.yaml +++ b/.github/workflows/typos.yaml @@ -18,4 +18,4 @@ jobs: - uses: actions/checkout@v4 - name: typos-action - uses: crate-ci/typos@v1.18.2 + uses: crate-ci/typos@v1.19.0 diff --git a/.release b/.release index f242ab11b..06cd02725 100644 --- a/.release +++ b/.release @@ -1 +1 @@ -v23.0.15 \ No newline at end of file +v23.1.0 \ No newline at end of file diff --git a/.vscode/settings.json b/.vscode/settings.json index d412e4162..45378908f 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -2,5 +2,5 @@ "python.linting.enabled": true, "python.formatting.provider": "yapf", "DockerRun.DisableDockerrc": true, - "augment.enableAutomaticCompletions": false + "augment.enableAutomaticCompletions": true } \ No newline at end of file diff --git a/Dockerfile b/Dockerfile index bafee9fed..0ff872a34 100644 --- a/Dockerfile +++ b/Dockerfile @@ -17,10 +17,10 @@ ARG PIP_NO_WARN_SCRIPT_LOCATION=0 ARG PIP_ROOT_USER_ACTION="ignore" # Install build dependencies -RUN apt-get update && apt-get upgrade -y && \ - apt-get install -y --no-install-recommends python3-launchpadlib git curl && \ - apt-get clean && \ - rm -rf /var/lib/apt/lists/* +RUN --mount=type=cache,id=apt-$TARGETARCH$TARGETVARIANT,sharing=locked,target=/var/cache/apt \ + --mount=type=cache,id=aptlists-$TARGETARCH$TARGETVARIANT,sharing=locked,target=/var/lib/apt/lists \ + apt-get update && apt-get upgrade -y && \ + apt-get install -y --no-install-recommends python3-launchpadlib git curl # Install PyTorch and TensorFlow # The versions must align and be in sync with the requirements_linux_docker.txt @@ -44,24 +44,49 @@ RUN --mount=type=cache,id=pip-$TARGETARCH$TARGETVARIANT,sharing=locked,target=/r # Replace pillow with pillow-simd (Only for x86) ARG TARGETPLATFORM -RUN if [ "$TARGETPLATFORM" = "linux/amd64" ]; then \ +RUN --mount=type=cache,id=apt-$TARGETARCH$TARGETVARIANT,sharing=locked,target=/var/cache/apt \ + --mount=type=cache,id=aptlists-$TARGETARCH$TARGETVARIANT,sharing=locked,target=/var/lib/apt/lists \ + if [ "$TARGETPLATFORM" = "linux/amd64" ]; then \ apt-get update && apt-get install -y --no-install-recommends zlib1g-dev libjpeg62-turbo-dev build-essential && \ - apt-get clean && \ - rm -rf /var/lib/apt/lists/* && \ pip uninstall -y pillow && \ CC="cc -mavx2" pip install -U --force-reinstall pillow-simd; \ fi FROM python:3.10-slim as final +ARG TARGETARCH +ARG TARGETVARIANT + ENV NVIDIA_VISIBLE_DEVICES all ENV NVIDIA_DRIVER_CAPABILITIES compute,utility +WORKDIR /tmp + +ENV CUDA_VERSION=12.1.1 +ENV NV_CUDA_CUDART_VERSION=12.1.105-1 +ENV NVIDIA_REQUIRE_CUDA=cuda>=12.1 +ENV NV_CUDA_COMPAT_PACKAGE=cuda-compat-12-1 + +# Install CUDA partially +ADD https://developer.download.nvidia.com/compute/cuda/repos/debian11/x86_64/cuda-keyring_1.0-1_all.deb . +RUN --mount=type=cache,id=apt-$TARGETARCH$TARGETVARIANT,sharing=locked,target=/var/cache/apt \ + --mount=type=cache,id=aptlists-$TARGETARCH$TARGETVARIANT,sharing=locked,target=/var/lib/apt/lists \ + dpkg -i cuda-keyring_1.0-1_all.deb && \ + rm cuda-keyring_1.0-1_all.deb && \ + sed -i 's/^Components: main$/& contrib/' /etc/apt/sources.list.d/debian.sources && \ + apt-get update && \ + apt-get install -y --no-install-recommends \ + # Installing the whole CUDA typically increases the image size by approximately **8GB**. + # To decrease the image size, we opt to install only the necessary libraries. 
+ # Here is the package list for your reference: https://developer.download.nvidia.com/compute/cuda/repos/debian11/x86_64 + # !If you experience any related issues, replace the following line with `cuda-12-1` to obtain the complete CUDA package. + cuda-cudart-12-1=${NV_CUDA_CUDART_VERSION} ${NV_CUDA_COMPAT_PACKAGE} libcusparse-12-1 libnvjitlink-12-1 + # Install runtime dependencies -RUN apt-get update && \ - apt-get install -y --no-install-recommends libgl1 libglib2.0-0 libjpeg62 libtcl8.6 libtk8.6 libgoogle-perftools-dev dumb-init && \ - apt-get clean && \ - rm -rf /var/lib/apt/lists/* +RUN --mount=type=cache,id=apt-$TARGETARCH$TARGETVARIANT,sharing=locked,target=/var/cache/apt \ + --mount=type=cache,id=aptlists-$TARGETARCH$TARGETVARIANT,sharing=locked,target=/var/lib/apt/lists \ + apt-get update && \ + apt-get install -y --no-install-recommends libgl1 libglib2.0-0 libjpeg62 libtcl8.6 libtk8.6 libgoogle-perftools-dev dumb-init # Fix missing libnvinfer7 RUN ln -s /usr/lib/x86_64-linux-gnu/libnvinfer.so /usr/lib/x86_64-linux-gnu/libnvinfer.so.7 && \ @@ -84,8 +109,9 @@ COPY --link --chmod=775 LICENSE.md /licenses/LICENSE.md COPY --link --chown=$UID:0 --chmod=775 --from=build /root/.local /home/$UID/.local COPY --link --chown=$UID:0 --chmod=775 . /app -ENV PATH="/home/$UID/.local/bin:$PATH" +ENV PATH="/usr/local/cuda/lib:/usr/local/cuda/lib64:/home/$UID/.local/bin:$PATH" ENV PYTHONPATH="${PYTHONPATH}:/home/$UID/.local/lib/python3.10/site-packages" +ENV LD_LIBRARY_PATH="/usr/local/cuda/lib:/usr/local/cuda/lib64:${LD_LIBRARY_PATH}" ENV LD_PRELOAD=libtcmalloc.so ENV PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python @@ -103,7 +129,7 @@ STOPSIGNAL SIGINT # Use dumb-init as PID 1 to handle signals properly ENTRYPOINT ["dumb-init", "--"] -CMD ["python3", "kohya_gui.py", "--listen", "0.0.0.0", "--server_port", "7860"] +CMD ["python3", "kohya_gui.py", "--listen", "0.0.0.0", "--server_port", "7860", "--headless"] ARG VERSION ARG RELEASE diff --git a/README.md b/README.md index e818fe406..1c6e57fbb 100644 --- a/README.md +++ b/README.md @@ -36,9 +36,14 @@ The GUI allows you to set the training parameters and generate and run the requi - [Troubleshooting](#troubleshooting) - [Page File Limit](#page-file-limit) - [No module called tkinter](#no-module-called-tkinter) + - [LORA Training on TESLA V100 - GPU Utilization Issue](#lora-training-on-tesla-v100---gpu-utilization-issue) + - [Issue Summary](#issue-summary) + - [Potential Solutions](#potential-solutions) - [SDXL training](#sdxl-training) + - [Masked loss](#masked-loss) - [Change History](#change-history) - - [2024/03/20 (v23.0.15)](#20240320-v23015) + - [2024/04/07 (v23.1.0)](#20240407-v2310) + - [2024/03/21 (v23.0.15)](#20240321-v23015) - [2024/03/19 (v23.0.14)](#20240319-v23014) - [2024/03/19 (v23.0.13)](#20240319-v23013) - [2024/03/16 (v23.0.12)](#20240316-v23012) @@ -46,16 +51,6 @@ The GUI allows you to set the training parameters and generate and run the requi - [Software Updates](#software-updates) - [Recommendations for Users](#recommendations-for-users) - [2024/03/13 (v23.0.11)](#20240313-v23011) - - [2024/03/13 (v23.0.9)](#20240313-v2309) - - [2024/03/12 (v23.0.8)](#20240312-v2308) - - [2024/03/12 (v23.0.7)](#20240312-v2307) - - [2024/03/11 (v23.0.6)](#20240311-v2306) - - [2024/03/11 (v23.0.5)](#20240311-v2305) - - [2024/03/10 (v23.0.4)](#20240310-v2304) - - [2024/03/10 (v23.0.3)](#20240310-v2303) - - [2024/03/10 (v23.0.2)](#20240310-v2302) - - [2024/03/09 (v23.0.1)](#20240309-v2301) - - [2024/03/02 (v23.0.0)](#20240302-v2300) ## 🦒 Colab @@ 
-330,10 +325,17 @@ gui.sh --listen 127.0.0.1 --server_port 7860 --inbrowser --share

## Custom Path Defaults

-You can now specify custom paths more easily:
+The repository now provides a default configuration file named `config.toml`. This file is a template that you can customize to suit your needs.

-- Simply copy the `config example.toml` file located in the root directory of the repository to `config.toml`.
-- Edit the `config.toml` file to adjust paths and settings according to your preferences.
+To use the default configuration file, follow these steps:
+
+1. Copy the `config example.toml` file from the root directory of the repository to `config.toml`.
+2. Open the `config.toml` file in a text editor.
+3. Modify the paths and settings as per your requirements.
+
+This approach makes it easy to set the default folders the GUI opens for each type of folder/file input it supports.
+
+You can specify the path to your `config.toml` (or any other name you like) when running the GUI. For instance: `./gui.bat --config c:\my_config.toml`

## LoRA

@@ -376,12 +378,119 @@ If you encounter an X error related to the page file, you may need to increase t

If you encounter an error indicating that the module `tkinter` is not found, try reinstalling Python 3.10 on your system.

+### LORA Training on TESLA V100 - GPU Utilization Issue
+
+#### Issue Summary
+
+When training LORA on a TESLA V100, users reported low GPU utilization, as well as difficulty in specifying GPUs other than the default for training.
+
+#### Potential Solutions
+
+- **GPU Selection:** Users can specify GPU IDs in the setup configuration to select the desired GPUs for training.
+- **Improving GPU Load:** Using the `AdamW8bit` optimizer and increasing the batch size can help achieve 70-80% GPU utilization without exceeding GPU memory limits.
+
## SDXL training

The documentation in this section will be moved to a separate document later.

+## Masked loss
+
+The masked loss is supported in each training script. To enable the masked loss, specify the `--masked_loss` option.
+
+The feature is not fully tested, so there may be bugs. If you find any issues, please open an Issue.
+
+A ControlNet dataset is used to specify the mask. The mask images should be RGB images. A pixel value of 255 in the R channel is treated as the mask (the loss is calculated only for the masked pixels), and 0 as the non-mask. Pixel values 0-255 are converted to weights 0-1 (i.e., a pixel value of 128 is treated as half the loss weight). See the [LLLite documentation](./docs/train_lllite_README.md#preparing-the-dataset) for details of the dataset specification.
+
## Change History

+### 2024/04/07 (v23.1.0)
+
+- Update sd-scripts to 0.8.7
+  - The default value of `huber_schedule` in Scheduled Huber Loss is changed from `exponential` to `snr`, which is expected to give better results.
+
+  - Highlights
+    - The dependent libraries are updated. Please see [Upgrade](#upgrade) and update the libraries.
+      - Especially `imagesize` is newly added, so if you cannot update the libraries immediately, please install with `pip install imagesize==1.4.1` separately.
+      - `bitsandbytes==0.43.0`, `prodigyopt==1.0`, `lion-pytorch==0.0.6` are included in the requirements.txt.
+        - `bitsandbytes` no longer requires complex procedures as it now officially supports Windows.
+      - Also, the PyTorch version is updated to 2.1.2 (PyTorch does not need to be updated immediately). In the upgrade procedure, PyTorch is not updated, so please manually install or update torch, torchvision, xformers if necessary (see [Upgrade PyTorch](#upgrade-pytorch)).
+    - When logging to wandb is enabled, the entire command line is exposed. Therefore, it is recommended to write the wandb API key and HuggingFace token in the configuration file (`.toml`) instead (see the sketch after this list). Thanks to bghira for raising the issue.
+      - A warning is displayed at the start of training if such information is included in the command line.
+      - Also, if there is an absolute path, the path may be exposed, so it is recommended to specify a relative path or write it in the configuration file. In such cases, an INFO log is displayed.
+      - See [#1123](https://github.com/kohya-ss/sd-scripts/pull/1123) and PR [#1240](https://github.com/kohya-ss/sd-scripts/pull/1240) for details.
+    - Colab seems to stop with log output. Try specifying the `--console_log_simple` option in the training script to disable rich logging.
+    - Other improvements include the addition of masked loss, Scheduled Huber Loss, DeepSpeed support, dataset settings improvements, and image tagging improvements. See below for details.
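For instance, with this GUI's `config.toml` (the full template appears later in this diff), the credentials can be kept out of the visible command line. A minimal sketch with placeholder values:

```toml
[advanced]
use_wandb = true
wandb_api_key = "0123abcd"   # placeholder value; keeps the secret out of the command line
wandb_run_name = "my-run"
```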
+
+  - Training scripts
+    - `train_network.py` and `sdxl_train_network.py` are modified to record some dataset settings in the metadata of the trained model (`caption_prefix`, `caption_suffix`, `keep_tokens_separator`, `secondary_separator`, `enable_wildcard`).
+    - Fixed a bug where the U-Net and Text Encoders were included in the state in `train_network.py` and `sdxl_train_network.py`. Saving and loading the state is now faster, the file size is smaller, and the memory usage when loading is reduced.
+    - DeepSpeed is supported. PR [#1101](https://github.com/kohya-ss/sd-scripts/pull/1101) and [#1139](https://github.com/kohya-ss/sd-scripts/pull/1139) Thanks to BootsofLagrangian! See PR [#1101](https://github.com/kohya-ss/sd-scripts/pull/1101) for details.
+    - The masked loss is supported in each training script. PR [#1207](https://github.com/kohya-ss/sd-scripts/pull/1207) See [Masked loss](#masked-loss) for details.
+    - Scheduled Huber Loss has been introduced to each training script. PR [#1228](https://github.com/kohya-ss/sd-scripts/pull/1228/) Thanks to kabachuha for the PR and cheald, drhead, and others for the discussion! See the PR and [Scheduled Huber Loss](./docs/train_lllite_README.md#scheduled-huber-loss) for details.
+    - The options `--noise_offset_random_strength` and `--ip_noise_gamma_random_strength` are added to each training script. These options can be used to vary the noise offset and ip noise gamma in the range of 0 to the specified value. PR [#1177](https://github.com/kohya-ss/sd-scripts/pull/1177) Thanks to KohakuBlueleaf!
+    - The option `--save_state_on_train_end` is added to each training script. PR [#1168](https://github.com/kohya-ss/sd-scripts/pull/1168) Thanks to gesen2egee!
+    - The options `--sample_every_n_epochs` and `--sample_every_n_steps` in each training script now display a warning and are ignored when a value less than or equal to `0` is specified. Thanks to S-Del for raising the issue.
+
+  - Dataset settings
+    - The [English version of the dataset settings documentation](./docs/config_README-en.md) is added. PR [#1175](https://github.com/kohya-ss/sd-scripts/pull/1175) Thanks to darkstorm2150!
+    - The `.toml` file for the dataset config is now read in UTF-8 encoding. PR [#1167](https://github.com/kohya-ss/sd-scripts/pull/1167) Thanks to Horizon1704!
+    - Fixed a bug where the last subset's settings were applied to all images when multiple subsets of regularization images were specified in the dataset settings. The settings for each subset are now correctly applied to each image. PR [#1205](https://github.com/kohya-ss/sd-scripts/pull/1205) Thanks to feffy380!
+    - Some features are added to the dataset subset settings (a hypothetical `.toml` sketch follows this list).
+      - `secondary_separator` is added to specify a tag separator that is not the target of shuffling or dropping. For example, specify `secondary_separator=";;;"`; the parts joined by this separator are neither shuffled nor dropped.
+      - `enable_wildcard` is added. When set to `true`, the wildcard notation `{aaa|bbb|ccc}` can be used. Multi-line captions are also enabled.
+      - `keep_tokens_separator` is updated so it can be used twice in the caption. When you specify `keep_tokens_separator="|||"`, the part divided by the second `|||` is not shuffled or dropped and remains at the end.
+      - The existing features `caption_prefix` and `caption_suffix` can be used together. `caption_prefix` and `caption_suffix` are processed first, and then `enable_wildcard`, `keep_tokens_separator`, shuffling and dropping, and `secondary_separator` are processed in order.
+      - See [Dataset config](./docs/config_README-en.md) for details.
+    - Datasets using the DreamBooth method now support caching image information (size, caption). PR [#1178](https://github.com/kohya-ss/sd-scripts/pull/1178) and [#1206](https://github.com/kohya-ss/sd-scripts/pull/1206) Thanks to KohakuBlueleaf! See [DreamBooth method specific options](./docs/config_README-en.md#dreambooth-specific-options) for details.
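A hypothetical dataset-config fragment combining the subset options described above (paths, separators, and values are illustrative; see the dataset config documentation for the authoritative format):

```toml
[[datasets]]
resolution = 512

  [[datasets.subsets]]
  image_dir = "./data/images"       # illustrative path
  caption_extension = ".txt"
  caption_prefix = "photo of, "     # processed first, together with caption_suffix
  caption_suffix = ", best quality"
  enable_wildcard = true            # allows {aaa|bbb|ccc} wildcard notation in captions
  secondary_separator = ";;;"       # tags joined by this are never shuffled or dropped
  keep_tokens_separator = "|||"     # may appear twice; the part after the second ||| stays at the end
```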
+
+  - Image tagging
+    - Support for v3 repositories is added to `tag_image_by_wd14_tagger.py` (`--onnx` option only). PR [#1192](https://github.com/kohya-ss/sd-scripts/pull/1192) Thanks to sdbds!
+      - ONNX may need to be updated; it is not installed by default, so please install or update it with `pip install onnx==1.15.0 onnxruntime-gpu==1.17.1` etc. Please also check the comments in `requirements.txt`.
+    - The model is now saved in a subdirectory named after `--repo_id` in `tag_image_by_wd14_tagger.py`. This caches multiple repo_id models. Please delete unnecessary files under `--model_dir`.
+    - Some options are added to `tag_image_by_wd14_tagger.py`.
+      - Some are added in PR [#1216](https://github.com/kohya-ss/sd-scripts/pull/1216) Thanks to Disty0!
+      - Output rating tags: `--use_rating_tags` and `--use_rating_tags_as_last_tag`
+      - Output character tags first: `--character_tags_first`
+      - Expand character tags and series: `--character_tag_expand`
+      - Specify tags to output first: `--always_first_tags`
+      - Replace tags: `--tag_replacement`
+      - See [Tagging documentation](./docs/wd14_tagger_README-en.md) for details.
+    - Fixed an error when specifying `--beam_search` and a value of 2 or more for `--num_beams` in `make_captions.py`.
+
+  - About Masked loss
+    The masked loss is supported in each training script. To enable the masked loss, specify the `--masked_loss` option.
+
+    The feature is not fully tested, so there may be bugs. If you find any issues, please open an Issue.
+
+    A ControlNet dataset is used to specify the mask. The mask images should be RGB images. A pixel value of 255 in the R channel is treated as the mask (the loss is calculated only for the masked pixels), and 0 as the non-mask. Pixel values 0-255 are converted to weights 0-1 (i.e., a pixel value of 128 is treated as half the loss weight). See the [LLLite documentation](./docs/train_lllite_README.md#preparing-the-dataset) for details of the dataset specification.
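As a rough illustration of the weighting described above — a minimal sketch, not the actual sd-scripts implementation; tensor shapes and names are assumptions:

```python
import torch

def apply_masked_loss(loss: torch.Tensor, mask_rgb: torch.Tensor) -> torch.Tensor:
    # loss: per-pixel loss, e.g. (B, C, H, W); mask_rgb: mask image (B, 3, H, W), values 0-255.
    # Only the R channel carries the mask: 255 -> weight 1.0, 128 -> ~0.5, 0 -> weight 0.0.
    weight = mask_rgb[:, 0:1].float() / 255.0
    # Average over the weighted pixels; the clamp avoids division by zero for an all-black mask.
    return (loss * weight).sum() / weight.sum().clamp(min=1e-6)
```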
+
+  - About Scheduled Huber Loss
+    Scheduled Huber Loss has been introduced to each training script. This is a method to improve robustness against outliers or anomalies (data corruption) in the training data.
+
+    With the traditional MSE (L2) loss function, the impact of outliers could be significant, potentially leading to a degradation in the quality of generated images. On the other hand, while the Huber loss function can suppress the influence of outliers, it tends to compromise the reproduction of fine details in images.
+
+    To address this, the proposed method employs a clever application of the Huber loss function. By scheduling the use of Huber loss in the early stages of training (when noise is high) and MSE in the later stages, it strikes a balance between outlier robustness and fine-detail reproduction.
+
+    Experimental results have confirmed that this method achieves higher accuracy on data containing outliers compared to pure Huber loss or MSE, while the increase in computational cost is minimal.
+
+    The newly added arguments `loss_type`, `huber_schedule`, and `huber_c` allow for the selection of the loss function type (Huber, smooth L1, MSE), the scheduling method (exponential, constant, SNR), and the Huber parameter. This enables optimization based on the characteristics of the dataset. A rough sketch of the idea follows the list below.
+
+    See PR [#1228](https://github.com/kohya-ss/sd-scripts/pull/1228/) for details.
+
+    - `loss_type`: Specify the loss function type. Choose `huber` for Huber loss, `smooth_l1` for smooth L1 loss, and `l2` for MSE loss. The default is `l2`, which is the same as before.
+    - `huber_schedule`: Specify the scheduling method. Choose `exponential`, `constant`, or `snr`. The default is `snr`.
+    - `huber_c`: Specify the Huber parameter. The default is `0.1`.
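A rough sketch of the scheduling idea — not sd-scripts' exact code; the pseudo-Huber form and the `exponential` decay are illustrative, and the `snr` schedule is omitted:

```python
import torch

def scheduled_huber_loss(
    pred: torch.Tensor,
    target: torch.Tensor,
    timestep: int,                   # diffusion timestep; larger = noisier
    num_train_timesteps: int = 1000,
    huber_c: float = 0.1,
    huber_schedule: str = "exponential",
) -> torch.Tensor:
    if huber_schedule == "constant":
        c = huber_c
    elif huber_schedule == "exponential":
        # c decays from 1.0 at t=0 (behaves like MSE) toward huber_c at t=T
        # (robust, Huber-like on the noisy high-timestep steps)
        c = huber_c ** (timestep / num_train_timesteps)
    else:
        raise ValueError(f"unsupported schedule: {huber_schedule}")

    err = pred - target
    # Pseudo-Huber: ~quadratic (MSE-like) for |err| << c, ~linear for |err| >> c
    return (2 * c * (torch.sqrt(err**2 + c * c) - c)).mean()
```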
+
+    Please read [Releases](https://github.com/kohya-ss/sd-scripts/releases) for recent updates.
+
+- Added GUI support for the new parameters listed above.
+- Moved accelerate launch parameters to a new `Accelerate launch` accordion above the `Model` accordion.
+- Added support for `Debiased Estimation loss` to Dreambooth settings.
+- Added support for "Dataset Preparation" defaults via the config.toml file.
+- Added a field to allow for the input of extra accelerate launch arguments.
+- Added new caption tool from https://github.com/kainatquaderee
+
### 2024/03/21 (v23.0.15)

- Add support for toml dataset configuration file to all trainers

@@ -431,50 +540,3 @@ The documentation in this section will be moved to a separate document later.
 - Increase icon size.
 - More setup fixes.
-
-### 2024/03/13 (v23.0.9)
-
-- Reworked how setup can be run to improve Stability Matrix support.
-- Added support for huggingface-based vea path.
-
-### 2024/03/12 (v23.0.8)
-
-- Add the ability to create output and logs folder if it does not exist
-
-### 2024/03/12 (v23.0.7)
-
-- Fixed minor issues related to functions and file paths.
-
-### 2024/03/11 (v23.0.6)
-
-- Fixed an issue with PYTHON paths that have "spaces" in them.
-
-### 2024/03/11 (v23.0.5)
-
-- Updated python module verification.
-- Removed cudnn module installation in Windows.
-
-### 2024/03/10 (v23.0.4)
-
-- Updated bitsandbytes to 0.43.0.
-- Added packaging to runpod setup.
-
-### 2024/03/10 (v23.0.3)
-
-- Fixed a bug with setup.
-- Enforced proper python version before running the GUI to prevent issues with execution of the GUI.
-
-### 2024/03/10 (v23.0.2)
-
-- Improved validation of the path provided by users before running training.
-
-### 2024/03/09 (v23.0.1)
-
-- Updated bitsandbytes module to 0.43.0 as it provides native Windows support.
-- Minor fixes to the code.
-
-### 2024/03/02 (v23.0.0)
-
-- Used sd-scripts release [0.8.4](https://github.com/kohya-ss/sd-scripts/releases/tag/v0.8.4) post commit [fccbee27277d65a8dcbdeeb81787ed4116b92e0b](https://github.com/kohya-ss/sd-scripts/commit/fccbee27277d65a8dcbdeeb81787ed4116b92e0b).
-- Major code refactoring thanks to @wkpark. This will make updating sd-scripts cleaner by keeping sd-scripts files separate from the GUI files. This will also make configuration more streamlined with fewer tabs and more accordion elements. Hope you like the new style.
-- This new release is implementing a significant structure change, moving all of the sd-scripts written by kohya under a folder called sd-scripts in the root of this project. This folder is a submodule that will be populated during setup or GUI execution.
diff --git a/activate.ps1 b/activate.ps1
deleted file mode 100644
index ae3888be3..000000000
--- a/activate.ps1
+++ /dev/null
@@ -1 +0,0 @@
-.\venv\Scripts\activate
\ No newline at end of file
diff --git a/js/localization.js b/assets/js/localization.js
similarity index 100%
rename from js/localization.js
rename to assets/js/localization.js
diff --git a/js/script.js b/assets/js/script.js
similarity index 100%
rename from js/script.js
rename to assets/js/script.js
diff --git a/style.css b/assets/style.css
similarity index 100%
rename from style.css
rename to assets/style.css
diff --git a/bitsandbytes_windows/cextension.py b/bitsandbytes_windows/cextension.py
deleted file mode 100644
index d38684a20..000000000
--- a/bitsandbytes_windows/cextension.py
+++ /dev/null
@@ -1,54 +0,0 @@
-import ctypes as ct
-from pathlib import Path
-from warnings import warn
-
-from .cuda_setup.main import evaluate_cuda_setup
-
-
-class CUDALibrary_Singleton(object):
-    _instance = None
-
-    def __init__(self):
-        raise RuntimeError("Call get_instance() instead")
-
-    def initialize(self):
-        binary_name = evaluate_cuda_setup()
-        package_dir = Path(__file__).parent
-        binary_path = package_dir / binary_name
-
-        if not binary_path.exists():
-            print(f"CUDA SETUP: TODO: compile library for specific version: {binary_name}")
-            legacy_binary_name = "libbitsandbytes.so"
-            print(f"CUDA SETUP: Defaulting to {legacy_binary_name}...")
-            binary_path = package_dir / legacy_binary_name
-            if not binary_path.exists():
-                print('CUDA SETUP: CUDA detection failed.
Either CUDA driver not installed, CUDA not installed, or you have multiple conflicting CUDA libraries!') - print('CUDA SETUP: If you compiled from source, try again with `make CUDA_VERSION=DETECTED_CUDA_VERSION` for example, `make CUDA_VERSION=113`.') - raise Exception('CUDA SETUP: Setup Failed!') - # self.lib = ct.cdll.LoadLibrary(binary_path) - self.lib = ct.cdll.LoadLibrary(str(binary_path)) # $$$ - else: - print(f"CUDA SETUP: Loading binary {binary_path}...") - # self.lib = ct.cdll.LoadLibrary(binary_path) - self.lib = ct.cdll.LoadLibrary(str(binary_path)) # $$$ - - @classmethod - def get_instance(cls): - if cls._instance is None: - cls._instance = cls.__new__(cls) - cls._instance.initialize() - return cls._instance - - -lib = CUDALibrary_Singleton.get_instance().lib -try: - lib.cadam32bit_g32 - lib.get_context.restype = ct.c_void_p - lib.get_cusparse.restype = ct.c_void_p - COMPILED_WITH_CUDA = True -except AttributeError: - warn( - "The installed version of bitsandbytes was compiled without GPU support. " - "8-bit optimizers and GPU quantization are unavailable." - ) - COMPILED_WITH_CUDA = False diff --git a/bitsandbytes_windows/libbitsandbytes_cpu.dll b/bitsandbytes_windows/libbitsandbytes_cpu.dll deleted file mode 100644 index b733af475..000000000 Binary files a/bitsandbytes_windows/libbitsandbytes_cpu.dll and /dev/null differ diff --git a/bitsandbytes_windows/libbitsandbytes_cuda116.dll b/bitsandbytes_windows/libbitsandbytes_cuda116.dll deleted file mode 100644 index a999316e9..000000000 Binary files a/bitsandbytes_windows/libbitsandbytes_cuda116.dll and /dev/null differ diff --git a/bitsandbytes_windows/libbitsandbytes_cuda118.dll b/bitsandbytes_windows/libbitsandbytes_cuda118.dll deleted file mode 100644 index a54cc960b..000000000 Binary files a/bitsandbytes_windows/libbitsandbytes_cuda118.dll and /dev/null differ diff --git a/bitsandbytes_windows/main.py b/bitsandbytes_windows/main.py deleted file mode 100644 index 7e5f9c981..000000000 --- a/bitsandbytes_windows/main.py +++ /dev/null @@ -1,166 +0,0 @@ -""" -extract factors the build is dependent on: -[X] compute capability - [ ] TODO: Q - What if we have multiple GPUs of different makes? -- CUDA version -- Software: - - CPU-only: only CPU quantization functions (no optimizer, no matrix multiple) - - CuBLAS-LT: full-build 8-bit optimizer - - no CuBLAS-LT: no 8-bit matrix multiplication (`nomatmul`) - -evaluation: - - if paths faulty, return meaningful error - - else: - - determine CUDA version - - determine capabilities - - based on that set the default path -""" - -import ctypes - -from .paths import determine_cuda_runtime_lib_path - - -def check_cuda_result(cuda, result_val): - # 3. Check for CUDA errors - if result_val != 0: - error_str = ctypes.c_char_p() - cuda.cuGetErrorString(result_val, ctypes.byref(error_str)) - print(f"CUDA exception! Error code: {error_str.value.decode()}") - -def get_cuda_version(cuda, cudart_path): - # https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART____VERSION.html#group__CUDART____VERSION - try: - cudart = ctypes.CDLL(cudart_path) - except OSError: - # TODO: shouldn't we error or at least warn here? 
- print(f'ERROR: libcudart.so could not be read from path: {cudart_path}!') - return None - - version = ctypes.c_int() - check_cuda_result(cuda, cudart.cudaRuntimeGetVersion(ctypes.byref(version))) - version = int(version.value) - major = version//1000 - minor = (version-(major*1000))//10 - - if major < 11: - print('CUDA SETUP: CUDA version lower than 11 are currently not supported for LLM.int8(). You will be only to use 8-bit optimizers and quantization routines!!') - - return f'{major}{minor}' - - -def get_cuda_lib_handle(): - # 1. find libcuda.so library (GPU driver) (/usr/lib) - try: - cuda = ctypes.CDLL("libcuda.so") - except OSError: - # TODO: shouldn't we error or at least warn here? - print('CUDA SETUP: WARNING! libcuda.so not found! Do you have a CUDA driver installed? If you are on a cluster, make sure you are on a CUDA machine!') - return None - check_cuda_result(cuda, cuda.cuInit(0)) - - return cuda - - -def get_compute_capabilities(cuda): - """ - 1. find libcuda.so library (GPU driver) (/usr/lib) - init_device -> init variables -> call function by reference - 2. call extern C function to determine CC - (https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__DEVICE__DEPRECATED.html) - 3. Check for CUDA errors - https://stackoverflow.com/questions/14038589/what-is-the-canonical-way-to-check-for-errors-using-the-cuda-runtime-api - # bits taken from https://gist.github.com/f0k/63a664160d016a491b2cbea15913d549 - """ - - - nGpus = ctypes.c_int() - cc_major = ctypes.c_int() - cc_minor = ctypes.c_int() - - device = ctypes.c_int() - - check_cuda_result(cuda, cuda.cuDeviceGetCount(ctypes.byref(nGpus))) - ccs = [] - for i in range(nGpus.value): - check_cuda_result(cuda, cuda.cuDeviceGet(ctypes.byref(device), i)) - ref_major = ctypes.byref(cc_major) - ref_minor = ctypes.byref(cc_minor) - # 2. call extern C function to determine CC - check_cuda_result( - cuda, cuda.cuDeviceComputeCapability(ref_major, ref_minor, device) - ) - ccs.append(f"{cc_major.value}.{cc_minor.value}") - - return ccs - - -# def get_compute_capability()-> Union[List[str, ...], None]: # FIXME: error -def get_compute_capability(cuda): - """ - Extracts the highest compute capbility from all available GPUs, as compute - capabilities are downwards compatible. If no GPUs are detected, it returns - None. - """ - ccs = get_compute_capabilities(cuda) - if ccs is not None: - # TODO: handle different compute capabilities; for now, take the max - return ccs[-1] - return None - - -def evaluate_cuda_setup(): - print('') - print('='*35 + 'BUG REPORT' + '='*35) - print('Welcome to bitsandbytes. For bug reports, please submit your error trace to: https://github.com/TimDettmers/bitsandbytes/issues') - print('For effortless bug reporting copy-paste your error into this form: https://docs.google.com/forms/d/e/1FAIpQLScPB8emS3Thkp66nvqwmjTEgxp8Y9ufuWTzFyr9kJ5AoI47dQ/viewform?usp=sf_link') - print('='*80) - return "libbitsandbytes_cuda116.dll" # $$$ - - binary_name = "libbitsandbytes_cpu.so" - #if not torch.cuda.is_available(): - #print('No GPU detected. Loading CPU library...') - #return binary_name - - cudart_path = determine_cuda_runtime_lib_path() - if cudart_path is None: - print( - "WARNING: No libcudart.so found! Install CUDA or the cudatoolkit package (anaconda)!" 
- ) - return binary_name - - print(f"CUDA SETUP: CUDA runtime path found: {cudart_path}") - cuda = get_cuda_lib_handle() - cc = get_compute_capability(cuda) - print(f"CUDA SETUP: Highest compute capability among GPUs detected: {cc}") - cuda_version_string = get_cuda_version(cuda, cudart_path) - - - if cc == '': - print( - "WARNING: No GPU detected! Check your CUDA paths. Processing to load CPU-only library..." - ) - return binary_name - - # 7.5 is the minimum CC vor cublaslt - has_cublaslt = cc in ["7.5", "8.0", "8.6"] - - # TODO: - # (1) CUDA missing cases (no CUDA installed by CUDA driver (nvidia-smi accessible) - # (2) Multiple CUDA versions installed - - # we use ls -l instead of nvcc to determine the cuda version - # since most installations will have the libcudart.so installed, but not the compiler - print(f'CUDA SETUP: Detected CUDA version {cuda_version_string}') - - def get_binary_name(): - "if not has_cublaslt (CC < 7.5), then we have to choose _nocublaslt.so" - bin_base_name = "libbitsandbytes_cuda" - if has_cublaslt: - return f"{bin_base_name}{cuda_version_string}.so" - else: - return f"{bin_base_name}{cuda_version_string}_nocublaslt.so" - - binary_name = get_binary_name() - - return binary_name diff --git a/config example.toml b/config example.toml index 5c9f1b052..69a781f90 100644 --- a/config example.toml +++ b/config example.toml @@ -2,15 +2,133 @@ # Edit the values to suit your needs # Default folders location -models_dir = "./models" # Pretrained model name or path -train_data_dir = "./data" # Image folder (containing training images subfolders) / Image folder (containing training images) -output_dir = "./outputs" # Output directory for trained model -reg_data_dir = "./data/reg" # Regularisation directory -logging_dir = "./logs" # Logging directory -config_dir = "./presets" # Load/Save Config file -log_tracker_config_dir = "./logs" # Log tracker configs directory -state_dir = "./outputs" # Resume from saved training state -vae_dir = "./models/vae" # VAEs folder path - -# Example custom folder location -# models_dir = "e:/models" # Pretrained model name or path +[model] +models_dir = "./models" # Pretrained model name or path +output_name = "new model" # Trained model output name +train_data_dir = "./data" # Image folder (containing training images subfolders) / Image folder (containing training images) +dataset_config = "./test.toml" # Dataset config file (Optional. 
Select the toml configuration file to use for the dataset) +training_comment = "Some training comment" # Training comment +save_model_as = "safetensors" # Save model as (ckpt, safetensors, diffusers, diffusers_safetensors) +save_precision = "bf16" # Save model precision (fp16, bf16, float) + +[folders] +output_dir = "./outputs" # Output directory for trained model +reg_data_dir = "./data/reg" # Regularisation directory +logging_dir = "./logs" # Logging directory + +[configuration] +config_dir = "./presets" # Load/Save Config file + +[accelerate_launch] +extra_accelerate_launch_args = "" # Extra accelerate launch args +gpu_ids = "" # GPU IDs +main_process_port = 0 # Main process port +mixed_precision = "fp16" # Mixed precision (fp16, bf16, fp8) +multi_gpu = false # Multi GPU +num_cpu_threads_per_process = 2 # Number of CPU threads per process +num_machines = 1 # Number of machines +num_processes = 1 # Number of processes + +[basic] +cache_latents = true # Cache latents +cache_latents_to_disk = false # Cache latents to disk +caption_extension = ".txt" # Caption extension +enable_bucket = true # Enable bucket +epoch = 1 # Epoch +learning_rate = 0.0001 # Learning rate +learning_rate_te = 0.0001 # Learning rate text encoder +learning_rate_te1 = 0.0001 # Learning rate text encoder 1 +learning_rate_te2 = 0.0001 # Learning rate text encoder 2 +lr_scheduler = "cosine" # LR Scheduler +lr_scheduler_args = "" # LR Scheduler args +lr_warmup = 0 # LR Warmup (% of total steps) +lr_scheduler_num_cycles = "" # LR Scheduler num cycles +lr_scheduler_power = "" # LR Scheduler power +max_bucket_reso = 2048 # Max bucket resolution +max_grad_norm = 1.0 # Max grad norm +max_resolution = "512,512" # Max resolution +max_train_steps = "" # Max train steps +max_train_epochs = "" # Max train epochs +min_bucket_reso = 256 # Min bucket resolution +optimizer = "AdamW8bit" # Optimizer (AdamW, AdamW8bit, Adafactor, DAdaptation, DAdaptAdaGrad, DAdaptAdam, DAdaptAdan, DAdaptAdanIP, DAdaptAdamPreprint, DAdaptLion, DAdaptSGD, Lion, Lion8bit, PagedAdam +optimizer_args = "" # Optimizer args +save_every_n_epochs = 1 # Save every n epochs +save_every_n_steps = 1 # Save every n steps +seed = "1234" # Seed +stop_text_encoder_training = 0 # Stop text encoder training (% of total steps) +train_batch_size = 1 # Train batch size + +[advanced] +adaptive_noise_scale = 0 # Adaptive noise scale +additional_parameters = "" # Additional parameters +bucket_no_upscale = true # Don't upscale bucket resolution +bucket_reso_steps = 64 # Bucket resolution steps +caption_dropout_every_n_epochs = 0 # Caption dropout every n epochs +caption_dropout_rate = 0 # Caption dropout rate +color_aug = false # Color augmentation +clip_skip = 1 # Clip skip +debiased_estimation_loss = false # Debiased estimation loss +flip_aug = false # Flip augmentation +fp8_base = false # FP8 base training (experimental) +full_bf16 = false # Full bf16 training (experimental) +full_fp16 = false # Full fp16 training (experimental) +gradient_accumulation_steps = 1 # Gradient accumulation steps +gradient_checkpointing = false # Gradient checkpointing +ip_noise_gamma = 0 # IP noise gamma +ip_noise_gamma_random_strength = false # IP noise gamma random strength (true, false) +keep_tokens = 0 # Keep tokens +log_tracker_config_dir = "./logs" # Log tracker configs directory +log_tracker_name = "" # Log tracker name +masked_loss = false # Masked loss +max_data_loader_n_workers = "0" # Max data loader n workers (string) +max_timestep = 1000 # Max timestep +max_token_length = "150" # Max 
token length ("75", "150", "225")
+mem_eff_attn = false # Memory efficient attention
+min_snr_gamma = 0 # Min SNR gamma
+min_timestep = 0 # Min timestep
+multires_noise_iterations = 0 # Multires noise iterations
+multires_noise_discount = 0 # Multires noise discount
+no_token_padding = false # Disable token padding
+noise_offset = 0 # Noise offset
+noise_offset_random_strength = false # Noise offset random strength (true, false)
+noise_offset_type = "Original" # Noise offset type ("Original", "Multires")
+persistent_data_loader_workers = false # Persistent data loader workers
+prior_loss_weight = 1.0 # Prior loss weight
+random_crop = false # Random crop
+save_every_n_steps = 0 # Save every n steps
+save_last_n_steps = 0 # Save last n steps
+save_last_n_steps_state = 0 # Save last n steps state
+save_state = false # Save state
+save_state_on_train_end = false # Save state on train end
+scale_v_pred_loss_like_noise_pred = false # Scale v pred loss like noise pred
+shuffle_caption = false # Shuffle captions
+state_dir = "./outputs" # Resume from saved training state
+use_wandb = false # Use wandb
+vae_batch_size = 0 # VAE batch size
+vae_dir = "./models/vae" # VAEs folder path
+v_pred_like_loss = 0 # V pred like loss weight
+wandb_api_key = "" # Wandb api key
+wandb_run_name = "" # Wandb run name
+weighted_captions = false # Weighted captions
+xformers = "xformers" # CrossAttention (none, sdp, xformers)
+
+# This next section can be used to set default values for the Dataset Preparation section
+# The "Destination training directory" field will be equal to "train_data_dir" as specified above
+[dataset_preparation]
+class_prompt = "class" # Class prompt
+images_folder = "/some/folder/where/images/are" # Training images directory
+instance_prompt = "instance" # Instance prompt
+reg_images_folder = "/some/folder/where/reg/images/are" # Regularisation images directory
+reg_images_repeat = 1 # Regularisation images repeat
+util_regularization_images_repeat_input = 1 # Regularisation images repeat input
+util_training_images_repeat_input = 40 # Training images repeat input
+
+[samples]
+sample_every_n_steps = 0 # Sample every n steps
+sample_every_n_epochs = 0 # Sample every n epochs
+sample_prompts = "" # Sample prompts
+sample_sampler = "euler_a" # Sampler to use for image sampling
+
+[sdxl]
+sdxl_cache_text_encoder_outputs = false # Cache text encoder outputs
+sdxl_no_half_vae = true # No half VAE
diff --git a/dreambooth_gui.py b/deprecated/dreambooth_gui.py
similarity index 96%
rename from dreambooth_gui.py
rename to deprecated/dreambooth_gui.py
index 998a53b38..d37e48cec 100644
--- a/dreambooth_gui.py
+++ b/deprecated/dreambooth_gui.py
@@ -20,8 +20,8 @@ def UI(**kwargs):
     headless = kwargs.get("headless", False)
     log.info(f"headless: {headless}")

-    if os.path.exists("./style.css"):
-        with open(os.path.join("./style.css"), "r", encoding="utf8") as file:
+    if os.path.exists("./assets/style.css"):
+        with open(os.path.join("./assets/style.css"), "r", encoding="utf8") as file:
             log.info("Load CSS...")
             css += file.read() + "\n"

diff --git a/finetune_gui.py b/deprecated/finetune_gui.py
similarity index 95%
rename from finetune_gui.py
rename to deprecated/finetune_gui.py
index 73eb81a25..c610dc670 100644
--- a/finetune_gui.py
+++ b/deprecated/finetune_gui.py
@@ -19,8 +19,8 @@ def UI(**kwargs):
     headless = kwargs.get("headless", False)
     log.info(f"headless: {headless}")

-    if os.path.exists("./style.css"):
-        with open(os.path.join("./style.css"), "r", encoding="utf8") as file:
+    if
os.path.exists("./assets/style.css"): + with open(os.path.join("./assets/style.css"), "r", encoding="utf8") as file: log.info("Load CSS...") css += file.read() + "\n" diff --git a/lora_gui.py b/deprecated/lora_gui.py similarity index 96% rename from lora_gui.py rename to deprecated/lora_gui.py index dd18859fa..532bbdd14 100644 --- a/lora_gui.py +++ b/deprecated/lora_gui.py @@ -22,8 +22,8 @@ def UI(**kwargs): headless = kwargs.get("headless", False) log.info(f"headless: {headless}") - if os.path.exists("./style.css"): - with open(os.path.join("./style.css"), "r", encoding="utf8") as file: + if os.path.exists("./assets/style.css"): + with open(os.path.join("./assets/style.css"), "r", encoding="utf8") as file: log.info("Load CSS...") css += file.read() + "\n" diff --git a/textual_inversion_gui.py b/deprecated/textual_inversion_gui.py similarity index 96% rename from textual_inversion_gui.py rename to deprecated/textual_inversion_gui.py index 0ebc0f209..a78b02206 100644 --- a/textual_inversion_gui.py +++ b/deprecated/textual_inversion_gui.py @@ -20,8 +20,8 @@ def UI(**kwargs): headless = kwargs.get("headless", False) log.info(f"headless: {headless}") - if os.path.exists("./style.css"): - with open(os.path.join("./style.css"), "r", encoding="utf8") as file: + if os.path.exists("./assets/style.css"): + with open(os.path.join("./assets/style.css"), "r", encoding="utf8") as file: log.info("Load CSS...") css += file.read() + "\n" diff --git a/utilities_gui.py b/deprecated/utilities_gui.py similarity index 92% rename from utilities_gui.py rename to deprecated/utilities_gui.py index 65cc066e5..5cf8bc41c 100644 --- a/utilities_gui.py +++ b/deprecated/utilities_gui.py @@ -15,8 +15,8 @@ def UI(**kwargs): css = '' - if os.path.exists('./style.css'): - with open(os.path.join('./style.css'), 'r', encoding='utf8') as file: + if os.path.exists('./assets/style.css'): + with open(os.path.join('./assets/style.css'), 'r', encoding='utf8') as file: print('Load CSS...') css += file.read() + '\n' diff --git a/gui.sh b/gui.sh index 1d6e5547b..5a0c544cc 100755 --- a/gui.sh +++ b/gui.sh @@ -74,6 +74,8 @@ else if [ "$RUNPOD" = false ]; then if [[ "$@" == *"--use-ipex"* ]]; then REQUIREMENTS_FILE="$SCRIPT_DIR/requirements_linux_ipex.txt" + elif [[ "$@" == *"--use-rocm"* ]]; then + REQUIREMENTS_FILE="$SCRIPT_DIR/requirements_linux_rocm.txt" else REQUIREMENTS_FILE="$SCRIPT_DIR/requirements_linux.txt" fi @@ -90,9 +92,12 @@ then fi export NEOReadDebugKeys=1 export ClDeviceGlobalMemSizeAvailablePercent=100 - if [[ -z "$STARTUP_CMD" ]] && [[ -z "$DISABLE_IPEXRUN" ]] && [ -x "$(command -v ipexrun)" ] + if [[ ! 
-z "${IPEXRUN}" ]] && [ ${IPEXRUN}="True" ] && [ -x "$(command -v ipexrun)" ] then - STARTUP_CMD=ipexrun + if [[ -z "$STARTUP_CMD" ]] + then + STARTUP_CMD=ipexrun + fi if [[ -z "$STARTUP_CMD_ARGS" ]] then STARTUP_CMD_ARGS="--multi-task-manager taskset --memory-allocator tcmalloc" diff --git a/kohya_gui.py b/kohya_gui.py index aff1e5555..cb98ef908 100644 --- a/kohya_gui.py +++ b/kohya_gui.py @@ -2,11 +2,11 @@ import os import argparse from kohya_gui.class_gui_config import KohyaSSGUIConfig -from dreambooth_gui import dreambooth_tab -from finetune_gui import finetune_tab -from textual_inversion_gui import ti_tab +from kohya_gui.dreambooth_gui import dreambooth_tab +from kohya_gui.finetune_gui import finetune_tab +from kohya_gui.textual_inversion_gui import ti_tab from kohya_gui.utilities import utilities_tab -from lora_gui import lora_tab +from kohya_gui.lora_gui import lora_tab from kohya_gui.class_lora_tab import LoRATools from kohya_gui.custom_logging import setup_logging @@ -23,8 +23,8 @@ def UI(**kwargs): headless = kwargs.get("headless", False) log.info(f"headless: {headless}") - if os.path.exists("./style.css"): - with open(os.path.join("./style.css"), "r", encoding="utf8") as file: + if os.path.exists("./assets/style.css"): + with open(os.path.join("./assets/style.css"), "r", encoding="utf8") as file: log.debug("Load CSS...") css += file.read() + "\n" @@ -40,7 +40,7 @@ def UI(**kwargs): css=css, title=f"Kohya_ss GUI {release}", theme=gr.themes.Default() ) - config = KohyaSSGUIConfig() + config = KohyaSSGUIConfig(config_file_path=kwargs.get("config_file_path")) with interface: with gr.Tab("Dreambooth"): @@ -105,6 +105,12 @@ def UI(**kwargs): if __name__ == "__main__": # torch.cuda.set_per_process_memory_fraction(0.48) parser = argparse.ArgumentParser() + parser.add_argument( + "--config", + type=str, + default="./config.toml", + help="Path to the toml config file for interface defaults", + ) parser.add_argument( "--listen", type=str, @@ -133,10 +139,12 @@ def UI(**kwargs): ) parser.add_argument("--use-ipex", action="store_true", help="Use IPEX environment") + parser.add_argument("--use-rocm", action="store_true", help="Use ROCm environment") args = parser.parse_args() UI( + config_file_path=args.config, username=args.username, password=args.password, inbrowser=args.inbrowser, diff --git a/kohya_gui/basic_caption_gui.py b/kohya_gui/basic_caption_gui.py index 0d49fb18c..ed8d46d21 100644 --- a/kohya_gui/basic_caption_gui.py +++ b/kohya_gui/basic_caption_gui.py @@ -1,7 +1,13 @@ import gradio as gr from easygui import msgbox import subprocess -from .common_gui import get_folder_path, add_pre_postfix, find_replace, scriptdir, list_dirs +from .common_gui import ( + get_folder_path, + add_pre_postfix, + find_replace, + scriptdir, + list_dirs, +) import os import sys @@ -12,6 +18,7 @@ PYTHON = sys.executable + def caption_images( caption_text: str, images_dir: str, @@ -41,26 +48,26 @@ def caption_images( # Check if images_dir is provided if not images_dir: msgbox( - 'Image folder is missing. Please provide the directory containing the images to caption.' + "Image folder is missing. Please provide the directory containing the images to caption." 
) return # Check if caption_ext is provided if not caption_ext: - msgbox('Please provide an extension for the caption files.') + msgbox("Please provide an extension for the caption files.") return # Log the captioning process if caption_text: - log.info(f'Captioning files in {images_dir} with {caption_text}...') + log.info(f"Captioning files in {images_dir} with {caption_text}...") # Build the command to run caption.py - run_cmd = fr'"{PYTHON}" "{scriptdir}/tools/caption.py"' + run_cmd = rf'"{PYTHON}" "{scriptdir}/tools/caption.py"' run_cmd += f' --caption_text="{caption_text}"' # Add optional flags to the command if overwrite: - run_cmd += f' --overwrite' + run_cmd += f" --overwrite" if caption_ext: run_cmd += f' --caption_file_ext="{caption_ext}"' @@ -71,7 +78,9 @@ def caption_images( # Set the environment variable for the Python path env = os.environ.copy() - env['PYTHONPATH'] = fr"{scriptdir}{os.pathsep}{scriptdir}/tools{os.pathsep}{env.get('PYTHONPATH', '')}" + env["PYTHONPATH"] = ( + rf"{scriptdir}{os.pathsep}{scriptdir}/tools{os.pathsep}{env.get('PYTHONPATH', '')}" + ) # Run the command based on the operating system subprocess.run(run_cmd, shell=True, env=env) @@ -102,7 +111,7 @@ def caption_images( ) # Log the end of the captioning process - log.info('Captioning done.') + log.info("Captioning done.") # Gradio UI @@ -121,7 +130,11 @@ def gradio_basic_caption_gui_tab(headless=False, default_images_dir=None): from .common_gui import create_refresh_button # Set default images directory if not provided - default_images_dir = default_images_dir if default_images_dir is not None else os.path.join(scriptdir, "data") + default_images_dir = ( + default_images_dir + if default_images_dir is not None + else os.path.join(scriptdir, "data") + ) current_images_dir = default_images_dir # Function to list directories @@ -141,26 +154,34 @@ def list_images_dirs(path): return list(list_dirs(path)) # Gradio tab for basic captioning - with gr.Tab('Basic Captioning'): + with gr.Tab("Basic Captioning"): # Markdown description gr.Markdown( - 'This utility allows you to create simple caption files for each image in a folder.' + "This utility allows you to create simple caption files for each image in a folder." 
) # Group and row for image folder selection with gr.Group(), gr.Row(): # Dropdown for image folder images_dir = gr.Dropdown( - label='Image folder to caption (containing the images to caption)', + label="Image folder to caption (containing the images to caption)", choices=[""] + list_images_dirs(default_images_dir), value="", interactive=True, allow_custom_value=True, ) # Refresh button for image folder - create_refresh_button(images_dir, lambda: None, lambda: {"choices": list_images_dirs(current_images_dir)},"open_folder_small") + create_refresh_button( + images_dir, + lambda: None, + lambda: {"choices": list_images_dirs(current_images_dir)}, + "open_folder_small", + ) # Button to open folder folder_button = gr.Button( - '📂', elem_id='open_folder_small', elem_classes=["tool"], visible=(not headless) + "📂", + elem_id="open_folder_small", + elem_classes=["tool"], + visible=(not headless), ) # Event handler for button click folder_button.click( @@ -170,14 +191,14 @@ def list_images_dirs(path): ) # Textbox for caption file extension caption_ext = gr.Textbox( - label='Caption file extension', - placeholder='Extension for caption file (e.g., .caption, .txt)', - value='.txt', + label="Caption file extension", + placeholder="Extension for caption file (e.g., .caption, .txt)", + value=".txt", interactive=True, ) # Checkbox to overwrite existing captions overwrite = gr.Checkbox( - label='Overwrite existing captions in folder', + label="Overwrite existing captions in folder", interactive=True, value=False, ) @@ -185,41 +206,41 @@ def list_images_dirs(path): with gr.Row(): # Textbox for caption prefix prefix = gr.Textbox( - label='Prefix to add to caption', - placeholder='(Optional)', + label="Prefix to add to caption", + placeholder="(Optional)", interactive=True, ) # Textbox for caption text caption_text = gr.Textbox( - label='Caption text', + label="Caption text", placeholder='e.g., "by some artist". Leave empty if you only want to add a prefix or postfix.', interactive=True, lines=2, ) # Textbox for caption postfix postfix = gr.Textbox( - label='Postfix to add to caption', - placeholder='(Optional)', + label="Postfix to add to caption", + placeholder="(Optional)", interactive=True, ) # Group and row for find and replace text with gr.Group(), gr.Row(): # Textbox for find text find_text = gr.Textbox( - label='Find text', + label="Find text", placeholder='e.g., "by some artist". Leave empty if you only want to add a prefix or postfix.', interactive=True, lines=2, ) # Textbox for replace text replace_text = gr.Textbox( - label='Replacement text', + label="Replacement text", placeholder='e.g., "by some artist". 
Leave empty if you want to replace with nothing.', interactive=True, lines=2, ) # Button to caption images - caption_button = gr.Button('Caption images') + caption_button = gr.Button("Caption images") # Event handler for button click caption_button.click( caption_images, diff --git a/kohya_gui/blip2_caption_gui.py b/kohya_gui/blip2_caption_gui.py new file mode 100644 index 000000000..28cc6f8bb --- /dev/null +++ b/kohya_gui/blip2_caption_gui.py @@ -0,0 +1,352 @@ +from PIL import Image +from transformers import Blip2Processor, Blip2ForConditionalGeneration +import torch +import gradio as gr +import os + +from .common_gui import get_folder_path, scriptdir, list_dirs +from .custom_logging import setup_logging + +# Set up logging +log = setup_logging() + + +def load_model(): + # Set the device to GPU if available, otherwise use CPU + device = "cuda" if torch.cuda.is_available() else "cpu" + + # Initialize the BLIP2 processor + processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b") + + # Initialize the BLIP2 model + model = Blip2ForConditionalGeneration.from_pretrained( + "Salesforce/blip2-opt-2.7b", torch_dtype=torch.float16 + ) + + # Move the model to the specified device + model.to(device) + + return processor, model, device + + +def get_images_in_directory(directory_path): + """ + Returns a list of image file paths found in the provided directory path. + + Parameters: + - directory_path: A string representing the path to the directory to search for images. + + Returns: + - A list of strings, where each string is the full path to an image file found in the specified directory. + """ + import os + + # List of common image file extensions to look for + image_extensions = [".jpg", ".jpeg", ".png", ".bmp", ".gif"] + + # Generate a list of image file paths in the directory + image_files = [ + # constructs the full path to the file + os.path.join(directory_path, file) + # lists all files and directories in the given path + for file in os.listdir(directory_path) + # gets the file extension in lowercase + if os.path.splitext(file)[1].lower() in image_extensions + ] + + # Return the list of image file paths + return image_files + + +def generate_caption( + file_list, + processor, + model, + device, + caption_file_ext=".txt", + num_beams=5, + repetition_penalty=1.5, + length_penalty=1.2, + max_new_tokens=40, + min_new_tokens=20, + do_sample=True, + top_p=0.0, +): + """ + Fetches and processes each image in file_list, generates captions based on the image, and writes the generated captions to a file. + + Parameters: + - file_list: A list of file paths pointing to the images to be captioned. + - processor: The preprocessor for the BLIP2 model. + - model: The BLIP2 model to be used for generating captions. + - device: The device on which the computation is performed. + - extension: The extension for the output text files. + - num_beams: Number of beams for beam search. Default: 5. + - repetition_penalty: Penalty for repeating tokens. Default: 1.5. + - length_penalty: Penalty for sentence length. Default: 1.2. + - max_new_tokens: Maximum number of new tokens to generate. Default: 40. + - min_new_tokens: Minimum number of new tokens to generate. Default: 20. 
+ """ + for file_path in file_list: + image = Image.open(file_path) + + inputs = processor(images=image, return_tensors="pt").to(device, torch.float16) + + if top_p == 0.0: + generated_ids = model.generate( + **inputs, + num_beams=num_beams, + repetition_penalty=repetition_penalty, + length_penalty=length_penalty, + max_new_tokens=max_new_tokens, + min_new_tokens=min_new_tokens, + ) + else: + generated_ids = model.generate( + **inputs, + do_sample=do_sample, + top_p=top_p, + max_new_tokens=max_new_tokens, + min_new_tokens=min_new_tokens, + ) + + generated_text = processor.batch_decode( + generated_ids, skip_special_tokens=True + )[0].strip() + + # Construct the output file path by replacing the original file extension with the specified extension + output_file_path = os.path.splitext(file_path)[0] + caption_file_ext + + # Write the generated text to the output file + with open(output_file_path, "w") as output_file: + output_file.write(generated_text) + + # Log the image file path with a message about the fact that the caption was generated + log.info(f"{file_path} caption was generated") + + +def caption_images_beam_search( + directory_path, + num_beams, + repetition_penalty, + length_penalty, + min_new_tokens, + max_new_tokens, + caption_file_ext, +): + """ + Captions all images in the specified directory using the provided prompt. + + Parameters: + - directory_path: A string representing the path to the directory containing the images to be captioned. + """ + log.info("BLIP2 captionning beam...") + + if not os.path.isdir(directory_path): + log.error(f"Directory {directory_path} does not exist.") + return + + processor, model, device = load_model() + image_files = get_images_in_directory(directory_path) + generate_caption( + file_list=image_files, + processor=processor, + model=model, + device=device, + num_beams=int(num_beams), + repetition_penalty=repetition_penalty, + length_penalty=length_penalty, + min_new_tokens=int(min_new_tokens), + max_new_tokens=int(max_new_tokens), + caption_file_ext=caption_file_ext, + ) + + +def caption_images_nucleus( + directory_path, + do_sample, + top_p, + min_new_tokens, + max_new_tokens, + caption_file_ext, +): + """ + Captions all images in the specified directory using the provided prompt. + + Parameters: + - directory_path: A string representing the path to the directory containing the images to be captioned. + """ + log.info("BLIP2 captionning nucleus...") + + if not os.path.isdir(directory_path): + log.error(f"Directory {directory_path} does not exist.") + return + + processor, model, device = load_model() + image_files = get_images_in_directory(directory_path) + generate_caption( + file_list=image_files, + processor=processor, + model=model, + device=device, + do_sample=do_sample, + top_p=top_p, + min_new_tokens=int(min_new_tokens), + max_new_tokens=int(max_new_tokens), + caption_file_ext=caption_file_ext, + ) + + +def gradio_blip2_caption_gui_tab(headless=False, directory_path=None): + from .common_gui import create_refresh_button + + directory_path = ( + directory_path + if directory_path is not None + else os.path.join(scriptdir, "data") + ) + current_train_dir = directory_path + + def list_train_dirs(path): + nonlocal current_train_dir + current_train_dir = path + return list(list_dirs(path)) + + with gr.Tab("BLIP2 Captioning"): + gr.Markdown( + "This utility uses BLIP2 to caption files for each image in a folder." 
+ ) + + with gr.Group(), gr.Row(): + directory_path_dir = gr.Dropdown( + label="Image folder to caption (containing the images to caption)", + choices=[""] + list_train_dirs(directory_path), + value="", + interactive=True, + allow_custom_value=True, + ) + create_refresh_button( + directory_path_dir, + lambda: None, + lambda: {"choices": list_train_dirs(current_train_dir)}, + "open_folder_small", + ) + button_directory_path_dir_input = gr.Button( + "📂", + elem_id="open_folder_small", + elem_classes=["tool"], + visible=(not headless), + ) + button_directory_path_dir_input.click( + get_folder_path, + outputs=directory_path_dir, + show_progress=False, + ) + with gr.Group(), gr.Row(): + min_new_tokens = gr.Number( + value=20, + label="Min new tokens", + interactive=True, + step=1, + minimum=5, + maximum=300, + ) + max_new_tokens = gr.Number( + value=40, + label="Max new tokens", + interactive=True, + step=1, + minimum=5, + maximum=300, + ) + caption_file_ext = gr.Textbox( + label="Caption file extension", + placeholder="Extension for caption file (e.g., .caption, .txt)", + value=".txt", + interactive=True, + ) + + with gr.Row(): + with gr.Tab("Beam search"): + with gr.Row(): + num_beams = gr.Slider( + minimum=1, + maximum=16, + value=16, + step=1, + interactive=True, + label="Number of beams", + ) + + temperature = gr.Slider( + minimum=0.5, + maximum=1.0, + value=1.0, + step=0.1, + interactive=True, + label="Temperature", + info="used with nucleus sampling", + ) + + len_penalty = gr.Slider( + minimum=-1.0, + maximum=2.0, + value=1.0, + step=0.2, + interactive=True, + label="Length Penalty", + info="increase for longer sequence", + ) + + rep_penalty = gr.Slider( + minimum=1.0, + maximum=5.0, + value=1.5, + step=0.5, + interactive=True, + label="Repeat Penalty", + info="larger value prevents repetition", + ) + + caption_button_beam = gr.Button( + value="Caption images", interactive=True, variant="primary" + ) + caption_button_beam.click( + caption_images_beam_search, + inputs=[ + directory_path_dir, + num_beams, + rep_penalty, + len_penalty, + min_new_tokens, + max_new_tokens, + caption_file_ext, + ], + ) + with gr.Tab("Nucleus sampling"): + with gr.Row(): + do_sample = gr.Checkbox(label="Sample", value=True) + + top_p = gr.Slider( + minimum=-0, + maximum=1, + value=0.9, + step=0.1, + interactive=True, + label="Top_p", + ) + + caption_button_nucleus = gr.Button( + value="Caption images", interactive=True, variant="primary" + ) + caption_button_nucleus.click( + caption_images_nucleus, + inputs=[ + directory_path_dir, + do_sample, + top_p, + min_new_tokens, + max_new_tokens, + caption_file_ext, + ], + ) diff --git a/kohya_gui/blip_caption_gui.py b/kohya_gui/blip_caption_gui.py index 1067ca30c..dab98454c 100644 --- a/kohya_gui/blip_caption_gui.py +++ b/kohya_gui/blip_caption_gui.py @@ -74,7 +74,9 @@ def caption_images( # Set up the environment env = os.environ.copy() - env["PYTHONPATH"] = f"{scriptdir}{os.pathsep}{scriptdir}/sd-scripts{os.pathsep}{env.get('PYTHONPATH', '')}" + env["PYTHONPATH"] = ( + f"{scriptdir}{os.pathsep}{scriptdir}/sd-scripts{os.pathsep}{env.get('PYTHONPATH', '')}" + ) # Run the command in the sd-scripts folder context subprocess.run(run_cmd, shell=True, env=env, cwd=f"{scriptdir}/sd-scripts") diff --git a/kohya_gui/class_accelerate_launch.py b/kohya_gui/class_accelerate_launch.py new file mode 100644 index 000000000..9f4bd73c8 --- /dev/null +++ b/kohya_gui/class_accelerate_launch.py @@ -0,0 +1,120 @@ +import gradio as gr +import os +from .class_gui_config import 
KohyaSSGUIConfig + + +class AccelerateLaunch: + def __init__( + self, + config: KohyaSSGUIConfig = {}, + ) -> None: + self.config = config + + with gr.Accordion("Resource Selection", open=True): + with gr.Row(): + self.mixed_precision = gr.Dropdown( + label="Mixed precision", + choices=["no", "fp16", "bf16", "fp8"], + value=self.config.get("accelerate_launch.mixed_precision", "fp16"), + info="Whether or not to use mixed precision training.", + ) + self.num_processes = gr.Number( + label="Number of processes", + value=self.config.get("accelerate_launch.num_processes", 1), + precision=0, + minimum=1, + info="The total number of processes to be launched in parallel.", + ) + self.num_machines = gr.Number( + label="Number of machines", + value=self.config.get("accelerate_launch.num_machines", 1), + precision=0, + minimum=1, + info="The total number of machines used in this training.", + ) + self.num_cpu_threads_per_process = gr.Slider( + minimum=1, + maximum=os.cpu_count(), + step=1, + label="Number of CPU threads per core", + value=self.config.get( + "accelerate_launch.num_cpu_threads_per_process", 2 + ), + info="The number of CPU threads per process.", + ) + with gr.Accordion("Hardware Selection", open=True): + with gr.Row(): + self.multi_gpu = gr.Checkbox( + label="Multi GPU", + value=self.config.get("accelerate_launch.multi_gpu", False), + info="Whether or not this should launch a distributed GPU training.", + ) + with gr.Accordion("Distributed GPUs", open=True): + with gr.Row(): + self.gpu_ids = gr.Textbox( + label="GPU IDs", + value=self.config.get("accelerate_launch.gpu_ids", ""), + placeholder="example: 0,1", + info=" What GPUs (by id) should be used for training on this machine as a comma-separated list", + ) + self.main_process_port = gr.Number( + label="Main process port", + value=self.config.get("accelerate_launch.main_process_port", 0), + precision=1, + minimum=0, + maximum=65535, + info="The port to use to communicate with the machine of rank 0.", + ) + with gr.Row(): + self.extra_accelerate_launch_args = gr.Textbox( + label="Extra accelerate launch arguments", + value=self.config.get( + "accelerate_launch.extra_accelerate_launch_args", "" + ), + placeholder="example: --same_network --machine_rank 4", + info="List of extra parameters to pass to accelerate launch", + ) + + def run_cmd(**kwargs): + run_cmd = "" + + if "extra_accelerate_launch_args" in kwargs: + extra_accelerate_launch_args = kwargs.get("extra_accelerate_launch_args") + if extra_accelerate_launch_args != "": + run_cmd += rf" {extra_accelerate_launch_args}" + + if "gpu_ids" in kwargs: + gpu_ids = kwargs.get("gpu_ids") + if not gpu_ids == "": + run_cmd += f' --gpu_ids="{gpu_ids}"' + + if "main_process_port" in kwargs: + main_process_port = kwargs.get("main_process_port") + if main_process_port > 0: + run_cmd += f' --main_process_port="{main_process_port}"' + + if "mixed_precision" in kwargs: + run_cmd += rf' --mixed_precision="{kwargs.get("mixed_precision")}"' + + if "multi_gpu" in kwargs: + if kwargs.get("multi_gpu"): + run_cmd += " --multi_gpu" + + if "num_processes" in kwargs: + num_processes = kwargs.get("num_processes") + if int(num_processes) > 0: + run_cmd += f" --num_processes={int(num_processes)}" + + if "num_machines" in kwargs: + num_machines = kwargs.get("num_machines") + if int(num_machines) > 0: + run_cmd += f" --num_machines={int(num_machines)}" + + if "num_cpu_threads_per_process" in kwargs: + num_cpu_threads_per_process = kwargs.get("num_cpu_threads_per_process") + if int(num_cpu_threads_per_process) > 
0: + run_cmd += ( + f" --num_cpu_threads_per_process={int(num_cpu_threads_per_process)}" + ) + + return run_cmd diff --git a/kohya_gui/class_advanced_training.py b/kohya_gui/class_advanced_training.py index 8b448b862..fa81b27b1 100644 --- a/kohya_gui/class_advanced_training.py +++ b/kohya_gui/class_advanced_training.py @@ -1,5 +1,4 @@ import gradio as gr -import os from typing import Tuple from .common_gui import ( get_folder_path, @@ -7,7 +6,7 @@ list_files, list_dirs, create_refresh_button, - document_symbol + document_symbol, ) @@ -47,10 +46,10 @@ def __init__( self.config = config # Determine the current directories for VAE and output, falling back to defaults if not specified. - self.current_vae_dir = self.config.get("vae_dir", "./models/vae") - self.current_state_dir = self.config.get("state_dir", "./outputs") + self.current_vae_dir = self.config.get("advanced.vae_dir", "./models/vae") + self.current_state_dir = self.config.get("advanced.state_dir", "./outputs") self.current_log_tracker_config_dir = self.config.get( - "log_tracker_config_dir", "./logs" + "advanced.log_tracker_config_dir", "./logs" ) # Define the behavior for changing noise offset type. @@ -76,19 +75,19 @@ def noise_offset_type_change( # Exclude token padding option for LoRA training type. if training_type != "lora": self.no_token_padding = gr.Checkbox( - label="No token padding", value=False + label="No token padding", value=self.config.get("advanced.no_token_padding", False) ) self.gradient_accumulation_steps = gr.Slider( label="Gradient accumulate steps", info="Number of updates steps to accumulate before performing a backward/update pass", - value="1", + value=self.config.get("advanced.gradient_accumulation_steps", 1), minimum=1, maximum=120, step=1, ) - self.weighted_captions = gr.Checkbox(label="Weighted captions", value=False) + self.weighted_captions = gr.Checkbox(label="Weighted captions", value=self.config.get("advanced.weighted_captions", False)) with gr.Group(), gr.Row(visible=not finetuning): - self.prior_loss_weight = gr.Number(label="Prior loss weight", value=1.0) + self.prior_loss_weight = gr.Number(label="Prior loss weight", value=self.config.get("advanced.prior_loss_weight", 1.0)) def list_vae_files(path): self.current_vae_dir = path if not path == "" else "." @@ -97,14 +96,14 @@ def list_vae_files(path): self.vae = gr.Dropdown( label="VAE (Optional: Path to checkpoint of vae for training)", interactive=True, - choices=[""] + list_vae_files(self.current_vae_dir), - value="", + choices=[self.config.get("advanced.vae_dir", "")] + list_vae_files(self.current_vae_dir), + value=self.config.get("advanced.vae_dir", ""), allow_custom_value=True, ) create_refresh_button( self.vae, lambda: None, - lambda: {"choices": [""] + list_vae_files(self.current_vae_dir)}, + lambda: {"choices": [self.config.get("advanced.vae_dir", "")] + list_vae_files(self.current_vae_dir)}, "open_folder_small", ) self.vae_button = gr.Button( @@ -117,7 +116,7 @@ def list_vae_files(path): ) self.vae.change( - fn=lambda path: gr.Dropdown(choices=[""] + list_vae_files(path)), + fn=lambda path: gr.Dropdown(choices=[self.config.get("advanced.vae_dir", "")] + list_vae_files(path)), inputs=self.vae, outputs=self.vae, show_progress=False, @@ -127,23 +126,24 @@ def list_vae_files(path): self.additional_parameters = gr.Textbox( label="Additional parameters", placeholder='(Optional) Use to provide additional parameters not handled by the GUI. 
Eg: --some_parameters "value"', + value=self.config.get("advanced.additional_parameters", ""), ) with gr.Row(): self.save_every_n_steps = gr.Number( label="Save every N steps", - value=0, + value=self.config.get("advanced.save_every_n_steps", 0), precision=0, info="(Optional) The model is saved every specified steps", ) self.save_last_n_steps = gr.Number( label="Save last N steps", - value=0, + value=self.config.get("advanced.save_last_n_steps", 0), precision=0, info="(Optional) Save only the specified number of models (old models will be deleted)", ) self.save_last_n_steps_state = gr.Number( label="Save last N steps state", - value=0, + value=self.config.get("advanced.save_last_n_steps_state", 0), precision=0, info="(Optional) Save only the specified number of states (old models will be deleted)", ) @@ -162,10 +162,10 @@ def full_options_update(full_fp16, full_bf16): ), gr.Checkbox(interactive=full_bf16_active) self.keep_tokens = gr.Slider( - label="Keep n tokens", value="0", minimum=0, maximum=32, step=1 + label="Keep n tokens", value=self.config.get("advanced.keep_tokens", 0), minimum=0, maximum=32, step=1 ) self.clip_skip = gr.Slider( - label="Clip skip", value="1", minimum=1, maximum=12, step=1 + label="Clip skip", value=self.config.get("advanced.clip_skip", 1), minimum=1, maximum=12, step=1 ) self.max_token_length = gr.Dropdown( label="Max Token Length", @@ -174,7 +174,7 @@ def full_options_update(full_fp16, full_bf16): "150", "225", ], - value="75", + value=self.config.get("advanced.max_token_length", "75"), ) with gr.Row(): @@ -182,15 +182,15 @@ def full_options_update(full_fp16, full_bf16): self.fp8_base = gr.Checkbox( label="fp8 base training (experimental)", info="U-Net and Text Encoder can be trained with fp8 (experimental)", - value=False, + value=self.config.get("advanced.fp8_base", False), ) self.full_fp16 = gr.Checkbox( label="Full fp16 training (experimental)", - value=False, + value=self.config.get("advanced.full_fp16", False), ) self.full_bf16 = gr.Checkbox( label="Full bf16 training (experimental)", - value=False, + value=self.config.get("advanced.full_bf16", False), info="Required bitsandbytes >= 0.36.0", ) @@ -207,48 +207,72 @@ def full_options_update(full_fp16, full_bf16): with gr.Row(): self.gradient_checkpointing = gr.Checkbox( - label="Gradient checkpointing", value=False + label="Gradient checkpointing", value=self.config.get("advanced.gradient_checkpointing", False) ) - self.shuffle_caption = gr.Checkbox(label="Shuffle caption", value=False) + self.shuffle_caption = gr.Checkbox(label="Shuffle caption", value=self.config.get("advanced.shuffle_caption", False)) self.persistent_data_loader_workers = gr.Checkbox( - label="Persistent data loader", value=False + label="Persistent data loader", value=self.config.get("advanced.persistent_data_loader_workers", False) ) self.mem_eff_attn = gr.Checkbox( - label="Memory efficient attention", value=False + label="Memory efficient attention", value=self.config.get("advanced.mem_eff_attn", False) ) with gr.Row(): self.xformers = gr.Dropdown( label="CrossAttention", choices=["none", "sdpa", "xformers"], - value="xformers", + value=self.config.get("advanced.xformers", "xformers"), + ) + self.color_aug = gr.Checkbox( + label="Color augmentation", + value=self.config.get("advanced.color_aug", False), + info="Enable weak color augmentation", + ) + self.flip_aug = gr.Checkbox( + label="Flip augmentation", + value=self.config.get("advanced.flip_aug", False), + info="Enable horizontal flip augmentation", + ) + self.masked_loss =
gr.Checkbox( + label="Masked loss", + value=self.config.get("advanced.masked_loss", False), + info="Apply a mask when calculating the loss. The dataset must provide conditioning_data_dir", + ) + with gr.Row(): + self.scale_v_pred_loss_like_noise_pred = gr.Checkbox( + label="Scale v prediction loss", + value=self.config.get("advanced.scale_v_pred_loss_like_noise_pred", False), + info="Only for SD v2 models. By scaling the loss according to the time step, the weights of global noise prediction and local noise prediction become the same, and improved detail may be expected.", ) - self.color_aug = gr.Checkbox(label="Color augmentation", value=False) - self.flip_aug = gr.Checkbox(label="Flip augmentation", value=False) self.min_snr_gamma = gr.Slider( label="Min SNR gamma", - value=0, + value=self.config.get("advanced.min_snr_gamma", 0), minimum=0, maximum=20, step=1, info="Recommended value of 5 when used", ) + self.debiased_estimation_loss = gr.Checkbox( + label="Debiased Estimation loss", + value=self.config.get("advanced.debiased_estimation_loss", False), + info="Automates the processing of noise, allowing for faster model fitting, as well as balancing out color issues. Do not use if Min SNR gamma is specified.", + ) with gr.Row(): # self.sdpa = gr.Checkbox(label='Use sdpa', value=False, info='Use sdpa for CrossAttention') self.bucket_no_upscale = gr.Checkbox( - label="Don't upscale bucket resolution", value=True + label="Don't upscale bucket resolution", value=self.config.get("advanced.bucket_no_upscale", True) ) self.bucket_reso_steps = gr.Slider( label="Bucket resolution steps", - value=64, + value=self.config.get("advanced.bucket_reso_steps", 64), minimum=1, maximum=128, ) self.random_crop = gr.Checkbox( - label="Random crop instead of center crop", value=False + label="Random crop instead of center crop", value=self.config.get("advanced.random_crop", False) ) self.v_pred_like_loss = gr.Slider( label="V Pred like loss", - value=0, + value=self.config.get("advanced.v_pred_like_loss", 0), minimum=0, maximum=1, step=0.01, @@ -258,7 +282,7 @@ def full_options_update(full_fp16, full_bf16): with gr.Row(): self.min_timestep = gr.Slider( label="Min Timestep", - value=0, + value=self.config.get("advanced.min_timestep", 0), step=1, minimum=0, maximum=1000, @@ -266,7 +290,7 @@ def full_options_update(full_fp16, full_bf16): ) self.max_timestep = gr.Slider( label="Max Timestep", - value=1000, + value=self.config.get("advanced.max_timestep", 1000), step=1, minimum=0, maximum=1000, @@ -280,41 +304,61 @@ def full_options_update(full_fp16, full_bf16): "Original", "Multires", ], - value="Original", + value=self.config.get("advanced.noise_offset_type", "Original"), + scale=1, ) with gr.Row(visible=True) as self.noise_offset_original: self.noise_offset = gr.Slider( label="Noise offset", - value=0, + value=self.config.get("advanced.noise_offset", 0), minimum=0, maximum=1, step=0.01, - info='Recommended values are 0.05 - 0.15', + info="Recommended values are 0.05 - 0.15", + ) + self.noise_offset_random_strength = gr.Checkbox( + label="Noise offset random strength", + value=self.config.get("advanced.noise_offset_random_strength", False), + info="Use random strength between 0~noise_offset for noise offset", ) + self.adaptive_noise_scale = gr.Slider( label="Adaptive noise scale", - value=0, + value=self.config.get("advanced.adaptive_noise_scale", 0), minimum=-1, maximum=1, step=0.001, - info="(Experimental, Optional) Since the latent is close to a normal distribution, it may be a good idea to specify a value
around 1/10 the noise offset.", + info="Add `latent mean absolute value * this value` to noise_offset", ) with gr.Row(visible=False) as self.noise_offset_multires: self.multires_noise_iterations = gr.Slider( label="Multires noise iterations", - value=0, + value=self.config.get("advanced.multires_noise_iterations", 0), minimum=0, maximum=64, step=1, - info='Enable multires noise (recommended values are 6-10)', + info="Enable multires noise (recommended values are 6-10)", ) self.multires_noise_discount = gr.Slider( label="Multires noise discount", - value=0, + value=self.config.get("advanced.multires_noise_discount", 0), + minimum=0, + maximum=1, + step=0.01, + info="Recommended value is 0.8. For LoRAs with small datasets, 0.1-0.3", + ) + with gr.Row(visible=True): + self.ip_noise_gamma = gr.Slider( + label="IP noise gamma", + value=self.config.get("advanced.ip_noise_gamma", 0), minimum=0, maximum=1, step=0.01, - info='Recommended values are 0.8. For LoRAs with small datasets, 0.1-0.3', + info="Enable input perturbation noise, used for regularization. Recommended value: around 0.1", + ) + self.ip_noise_gamma_random_strength = gr.Checkbox( + label="IP noise gamma random strength", + value=self.config.get("advanced.ip_noise_gamma_random_strength", False), + info="Use random strength between 0~ip_noise_gamma for input perturbation noise", ) self.noise_offset_type.change( noise_offset_type_change, @@ -326,16 +370,20 @@ def full_options_update(full_fp16, full_bf16): ) with gr.Row(): self.caption_dropout_every_n_epochs = gr.Number( - label="Dropout caption every n epochs", value=0 + label="Dropout caption every n epochs", value=self.config.get("advanced.caption_dropout_every_n_epochs", 0), ) self.caption_dropout_rate = gr.Slider( - label="Rate of caption dropout", value=0, minimum=0, maximum=1 + label="Rate of caption dropout", value=self.config.get("advanced.caption_dropout_rate", 0), minimum=0, maximum=1 ) self.vae_batch_size = gr.Slider( - label="VAE batch size", minimum=0, maximum=32, value=0, step=1 + label="VAE batch size", minimum=0, maximum=32, value=self.config.get("advanced.vae_batch_size", 0), step=1 ) with gr.Group(), gr.Row(): - self.save_state = gr.Checkbox(label="Save training state", value=False) + self.save_state = gr.Checkbox(label="Save training state", value=self.config.get("advanced.save_state", False)) + + self.save_state_on_train_end = gr.Checkbox( + label="Save training state at end of training", value=self.config.get("advanced.save_state_on_train_end", False) + ) def list_state_dirs(path): self.current_state_dir = path if not path == "" else "."
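
The two new ip_noise_gamma controls map to sd-scripts' input-perturbation options. Roughly, and as a sketch of the idea rather than the trainer's actual code: the latents are noised with a perturbed copy of the noise, while the loss target stays the unperturbed noise.

    import torch

    def perturb_noise(noise: torch.Tensor, ip_noise_gamma: float,
                      random_strength: bool = False) -> torch.Tensor:
        # With the "random strength" checkbox, a strength in [0, ip_noise_gamma)
        # is drawn each step instead of using the fixed value.
        strength = ip_noise_gamma * torch.rand(1).item() if random_strength else ip_noise_gamma
        return noise + strength * torch.randn_like(noise)

    # The perturbed noise only builds the model input; the loss is still
    # computed against the original `noise` tensor.
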
@@ -343,15 +391,15 @@ def list_state_dirs(path): self.resume = gr.Dropdown( label='Resume from saved training state (path to "last-state" state folder)', - choices=[""] + list_state_dirs(self.current_state_dir), - value="", + choices=[self.config.get("advanced.state_dir", "")] + list_state_dirs(self.current_state_dir), + value=self.config.get("advanced.state_dir", ""), interactive=True, allow_custom_value=True, ) create_refresh_button( self.resume, lambda: None, - lambda: {"choices": [""] + list_state_dirs(self.current_state_dir)}, + lambda: {"choices": [self.config.get("advanced.state_dir", "")] + list_state_dirs(self.current_state_dir)}, "open_folder_small", ) self.resume_button = gr.Button( @@ -363,7 +411,7 @@ def list_state_dirs(path): show_progress=False, ) self.resume.change( - fn=lambda path: gr.Dropdown(choices=[""] + list_state_dirs(path)), + fn=lambda path: gr.Dropdown(choices=[self.config.get("advanced.state_dir", "")] + list_state_dirs(path)), inputs=self.resume, outputs=self.resume, show_progress=False, @@ -371,34 +419,23 @@ def list_state_dirs(path): self.max_data_loader_n_workers = gr.Textbox( label="Max num workers for DataLoader", placeholder="(Optional) Override number of epoch. Default: 8", - value="0", + value=self.config.get("advanced.max_data_loader_n_workers", "0"), ) with gr.Row(): - self.num_processes = gr.Number( - label="Number of processes", value=1, precision=0, minimum=1 - ) - self.num_machines = gr.Number( - label="Number of machines", value=1, precision=0, minimum=1 - ) - self.multi_gpu = gr.Checkbox(label="Multi GPU", value=False) - self.gpu_ids = gr.Textbox( - label="GPU IDs", value="", placeholder="example: 0,1" ) with gr.Row(): self.use_wandb = gr.Checkbox( label="WANDB Logging", - value=False, + value=self.config.get("advanced.use_wandb", False), info="If unchecked, tensorboard will be used as the default for logging.", ) self.wandb_api_key = gr.Textbox( label="WANDB API Key", - value="", + value=self.config.get("advanced.wandb_api_key", ""), placeholder="(Optional)", info="Users can obtain and/or generate an API key in their user settings on the website: https://wandb.ai/login", ) self.wandb_run_name = gr.Textbox( label="WANDB run name", - value="", + value=self.config.get("advanced.wandb_run_name", ""), placeholder="(Optional)", info="The name of the specific wandb session", ) @@ -410,15 +447,15 @@ def list_log_tracker_config_files(path): self.log_tracker_name = gr.Textbox( label="Log tracker name", - value="", + value=self.config.get("advanced.log_tracker_name", ""), placeholder="(Optional)", info="Name of tracker to use for logging, default is script-specific default name", ) self.log_tracker_config = gr.Dropdown( label="Log tracker config", - choices=[""] + choices=[self.config.get("advanced.log_tracker_config_dir", "")] + list_log_tracker_config_files(self.current_log_tracker_config_dir), - value="", + value=self.config.get("advanced.log_tracker_config_dir", ""), info="Path to tracker config file to use for logging", interactive=True, allow_custom_value=True, @@ -427,7 +464,7 @@ def list_log_tracker_config_files(path): self.log_tracker_config, lambda: None, lambda: { - "choices": [""] + "choices": [self.config.get("advanced.log_tracker_config_dir", "")] + list_log_tracker_config_files(self.current_log_tracker_config_dir) }, "open_folder_small", @@ -442,20 +479,9 @@ def list_log_tracker_config_files(path): ) self.log_tracker_config.change( fn=lambda path: gr.Dropdown( - choices=[""] + list_log_tracker_config_files(path) +
choices=[self.config.get("advanced.log_tracker_config_dir", "")] + list_log_tracker_config_files(path) ), inputs=self.log_tracker_config, outputs=self.log_tracker_config, show_progress=False, ) - with gr.Row(): - self.scale_v_pred_loss_like_noise_pred = gr.Checkbox( - label="Scale v prediction loss", - value=False, - info="Only for SD v2 models. By scaling the loss according to the time step, the weights of global noise prediction and local noise prediction become the same, and the improvement of details may be expected.", - ) - self.debiased_estimation_loss = gr.Checkbox( - label="Debiased Estimation loss", - value=False, - info="Automates the processing of noise, allowing for faster model fitting, as well as balancing out color issues", - ) diff --git a/kohya_gui/class_basic_training.py b/kohya_gui/class_basic_training.py index 8eb3e8493..0df6ddf0a 100644 --- a/kohya_gui/class_basic_training.py +++ b/kohya_gui/class_basic_training.py @@ -1,5 +1,4 @@ import gradio as gr -import os from typing import Tuple @@ -25,6 +24,7 @@ def __init__( lr_warmup_value: str = "0", finetuning: bool = False, dreambooth: bool = False, + config: dict = {}, ) -> None: """ Initializes the BasicTraining object with the given parameters. @@ -43,6 +43,7 @@ def __init__( self.lr_warmup_value = lr_warmup_value self.finetuning = finetuning self.dreambooth = dreambooth + self.config = config # Initialize the UI components self.initialize_ui_components() @@ -76,28 +77,31 @@ def init_training_controls(self) -> None: with gr.Row(): # Initialize the train batch size slider self.train_batch_size = gr.Slider( - minimum=1, maximum=64, label="Train batch size", value=1, step=1 + minimum=1, maximum=64, label="Train batch size", value=self.config.get("basic.train_batch_size", 1), step=1, ) # Initialize the epoch number input - self.epoch = gr.Number(label="Epoch", value=1, precision=0) + self.epoch = gr.Number(label="Epoch", value=self.config.get("basic.epoch", 1), precision=0) # Initialize the maximum train epochs input self.max_train_epochs = gr.Textbox( label="Max train epoch", placeholder="(Optional) Enforce # epochs", + value=self.config.get("basic.max_train_epochs", ""), ) # Initialize the maximum train steps input self.max_train_steps = gr.Textbox( label="Max train steps", placeholder="(Optional) Enforce # steps", + value=self.config.get("basic.max_train_steps", ""), ) # Initialize the save every N epochs input self.save_every_n_epochs = gr.Number( - label="Save every N epochs", value=1, precision=0 + label="Save every N epochs", value=self.config.get("basic.save_every_n_epochs", 1), precision=0 ) # Initialize the caption extension input self.caption_extension = gr.Textbox( label="Caption Extension", placeholder="(Optional) default: .caption", + value=self.config.get("basic.caption_extension", ""), ) def init_precision_and_resources_controls(self) -> None: """ Initializes the precision and resources controls for the model.
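
The mixed-precision and CPU-thread controls removed from this class now live in the new AccelerateLaunch accordion, whose run_cmd helper (shown earlier in this diff) assembles the corresponding accelerate launch flags. A hypothetical call, with illustrative values:

    from kohya_gui.class_accelerate_launch import AccelerateLaunch

    # run_cmd only emits flags for the kwargs it receives and that pass its checks.
    flags = AccelerateLaunch.run_cmd(
        mixed_precision="bf16",
        multi_gpu=True,
        gpu_ids="0,1",
        num_processes=2,
        num_machines=1,
        num_cpu_threads_per_process=2,
        main_process_port=29500,
    )
    # flags == ' --gpu_ids="0,1" --main_process_port="29500" --mixed_precision="bf16"'
    #          ' --multi_gpu --num_processes=2 --num_machines=1'
    #          ' --num_cpu_threads_per_process=2'
    print(f"accelerate launch{flags} sd-scripts/train_network.py ...")
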
""" with gr.Row(): - # Initialize the mixed precision dropdown - self.mixed_precision = gr.Dropdown( - label="Mixed precision", choices=["no", "fp16", "bf16"], value="fp16" - ) - # Initialize the number of CPU threads per core slider - self.num_cpu_threads_per_process = gr.Slider( - minimum=1, - maximum=os.cpu_count(), - step=1, - label="Number of CPU threads per core", - value=2, - ) # Initialize the seed textbox - self.seed = gr.Textbox(label="Seed", placeholder="(Optional) eg:1234") + self.seed = gr.Textbox(label="Seed", placeholder="(Optional) eg:1234", value=self.config.get("basic.seed", "")) # Initialize the cache latents checkbox - self.cache_latents = gr.Checkbox(label="Cache latents", value=True) + self.cache_latents = gr.Checkbox(label="Cache latents", value=self.config.get("basic.cache_latents", True)) # Initialize the cache latents to disk checkbox self.cache_latents_to_disk = gr.Checkbox( - label="Cache latents to disk", value=False + label="Cache latents to disk", value=self.config.get("basic.cache_latents_to_disk", False) ) def init_lr_and_optimizer_controls(self) -> None: @@ -143,7 +135,7 @@ def init_lr_and_optimizer_controls(self) -> None: "linear", "polynomial", ], - value=self.lr_scheduler_value, + value=self.config.get("basic.lr_scheduler", self.lr_scheduler_value), ) # Initialize the optimizer dropdown self.optimizer = gr.Dropdown( @@ -169,7 +161,7 @@ def init_lr_and_optimizer_controls(self) -> None: "SGDNesterov", "SGDNesterov8bit", ], - value="AdamW8bit", + value=self.config.get("basic.optimizer", "AdamW8bit"), interactive=True, ) @@ -180,19 +172,21 @@ def init_grad_and_lr_controls(self) -> None: with gr.Row(): # Initialize the maximum gradient norm slider self.max_grad_norm = gr.Slider( - label="Max grad norm", value=1.0, minimum=0.0, maximum=1.0 + label="Max grad norm", value=self.config.get("basic.max_grad_norm", 1.0), minimum=0.0, maximum=1.0 ) # Initialize the learning rate scheduler extra arguments textbox self.lr_scheduler_args = gr.Textbox( label="LR scheduler extra arguments", lines=2, placeholder='(Optional) eg: "milestones=[1,10,30,50]" "gamma=0.1"', + value=self.config.get("basic.lr_scheduler_args", ""), ) # Initialize the optimizer extra arguments textbox self.optimizer_args = gr.Textbox( label="Optimizer extra arguments", lines=2, placeholder="(Optional) eg: relative_step=True scale_parameter=True warmup_init=True", + value=self.config.get("basic.optimizer_args", ""), ) def init_learning_rate_controls(self) -> None: @@ -209,7 +203,7 @@ def init_learning_rate_controls(self) -> None: # Initialize the learning rate number input self.learning_rate = gr.Number( label=lr_label, - value=self.learning_rate_value, + value=self.config.get("basic.learning_rate", self.learning_rate_value), minimum=0, maximum=1, info="Set to 0 to not train the Unet", @@ -217,7 +211,7 @@ def init_learning_rate_controls(self) -> None: # Initialize the learning rate TE number input self.learning_rate_te = gr.Number( label="Learning rate TE", - value=self.learning_rate_value, + value=self.config.get("basic.learning_rate_te", self.learning_rate_value), visible=self.finetuning or self.dreambooth, minimum=0, maximum=1, @@ -226,7 +220,7 @@ def init_learning_rate_controls(self) -> None: # Initialize the learning rate TE1 number input self.learning_rate_te1 = gr.Number( label="Learning rate TE1", - value=self.learning_rate_value, + value=self.config.get("basic.learning_rate_te1", self.learning_rate_value), visible=False, minimum=0, maximum=1, @@ -235,7 +229,7 @@ def 
init_learning_rate_controls(self) -> None: # Initialize the learning rate TE2 number input self.learning_rate_te2 = gr.Number( label="Learning rate TE2", - value=self.learning_rate_value, + value=self.config.get("basic.learning_rate_te2", self.learning_rate_value), visible=False, minimum=0, maximum=1, @@ -244,7 +238,7 @@ def init_learning_rate_controls(self) -> None: # Initialize the learning rate warmup slider self.lr_warmup = gr.Slider( label="LR warmup (% of total steps)", - value=self.lr_warmup_value, + value=self.config.get("basic.lr_warmup", self.lr_warmup_value), minimum=0, maximum=100, step=1, @@ -259,11 +253,13 @@ def init_scheduler_controls(self) -> None: self.lr_scheduler_num_cycles = gr.Textbox( label="LR # cycles", placeholder="(Optional) For Cosine with restart and polynomial only", + value=self.config.get("basic.lr_scheduler_num_cycles", ""), ) # Initialize the learning rate scheduler power textbox self.lr_scheduler_power = gr.Textbox( label="LR power", placeholder="(Optional) For Cosine with restart and polynomial only", + value=self.config.get("basic.lr_scheduler_power", ""), ) def init_resolution_and_bucket_controls(self) -> None: @@ -273,22 +269,22 @@ def init_resolution_and_bucket_controls(self) -> None: with gr.Row(visible=not self.finetuning): # Initialize the maximum resolution textbox self.max_resolution = gr.Textbox( - label="Max resolution", value="512,512", placeholder="512,512" + label="Max resolution", value=self.config.get("basic.max_resolution", "512,512"), placeholder="512,512" ) # Initialize the stop text encoder training slider self.stop_text_encoder_training = gr.Slider( minimum=-1, maximum=100, - value=0, + value=self.config.get("basic.stop_text_encoder_training", 0), step=1, label="Stop TE (% of total steps)", ) # Initialize the enable buckets checkbox - self.enable_bucket = gr.Checkbox(label="Enable buckets", value=True) + self.enable_bucket = gr.Checkbox(label="Enable buckets", value=self.config.get("basic.enable_bucket", True)) # Initialize the minimum bucket resolution slider self.min_bucket_reso = gr.Slider( label="Minimum bucket resolution", - value=256, + value=self.config.get("basic.min_bucket_reso", 256), minimum=64, maximum=4096, step=64, @@ -297,7 +293,7 @@ def init_resolution_and_bucket_controls(self) -> None: # Initialize the maximum bucket resolution slider self.max_bucket_reso = gr.Slider( label="Maximum bucket resolution", - value=2048, + value=self.config.get("basic.max_bucket_reso", 2048), minimum=64, maximum=4096, step=64, diff --git a/kohya_gui/class_command_executor.py b/kohya_gui/class_command_executor.py index d64ad92a9..ccd15a713 100644 --- a/kohya_gui/class_command_executor.py +++ b/kohya_gui/class_command_executor.py @@ -5,6 +5,7 @@ # Set up logging log = setup_logging() + class CommandExecutor: """ A class to execute and manage commands. @@ -43,7 +44,9 @@ def kill_command(self): log.info("The running process has been terminated.") except psutil.NoSuchProcess: # Explicitly handle the case where the process does not exist - log.info("The process does not exist. It might have terminated before the kill command was issued.") + log.info( + "The process does not exist. It might have terminated before the kill command was issued." 
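
For context on the psutil error handling reformatted here: terminating a training run has to take down the whole process tree, since accelerate spawns worker processes. A sketch of the usual pattern, assuming kill_command follows it around the excerpt above:

    import psutil

    def kill_process_tree(pid: int) -> None:
        # Kill children first so workers do not outlive the launcher;
        # psutil.NoSuchProcess propagates to the caller, which logs it.
        parent = psutil.Process(pid)
        for child in parent.children(recursive=True):
            child.kill()
        parent.kill()
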
+ ) except Exception as e: # General exception handling for any other errors log.info(f"Error when terminating process: {e}") diff --git a/kohya_gui/class_configuration_file.py b/kohya_gui/class_configuration_file.py index a99caa1dd..444e2c4ca 100644 --- a/kohya_gui/class_configuration_file.py +++ b/kohya_gui/class_configuration_file.py @@ -12,7 +12,9 @@ class ConfigurationFile: A class to handle configuration file operations in the GUI. """ - def __init__(self, headless: bool = False, config_dir: str = None, config:dict = {}): + def __init__( + self, headless: bool = False, config_dir: str = None, config: dict = {} + ): """ Initialize the ConfigurationFile class. @@ -22,11 +24,13 @@ def __init__(self, headless: bool = False, config_dir: str = None, config:dict = """ self.headless = headless - + self.config = config # Sets the directory for storing configuration files, defaults to a 'presets' folder within the script directory. - self.current_config_dir = self.config.get('config_dir', os.path.join(scriptdir, "presets")) + self.current_config_dir = self.config.get( + "config_dir", os.path.join(scriptdir, "presets") + ) # Initialize the GUI components for configuration. self.create_config_gui() @@ -56,8 +60,8 @@ def create_config_gui(self) -> None: # Dropdown for selecting or entering the name of a configuration file. self.config_file_name = gr.Dropdown( label="Load/Save Config file", - choices=[""] + self.list_config_dir(self.current_config_dir), - value="", + choices=[self.config.get("config_dir", "")] + self.list_config_dir(self.current_config_dir), + value=self.config.get("config_dir", ""), interactive=True, allow_custom_value=True, ) diff --git a/kohya_gui/class_folders.py b/kohya_gui/class_folders.py index 5a34d0afc..a0467fb51 100644 --- a/kohya_gui/class_folders.py +++ b/kohya_gui/class_folders.py @@ -1,12 +1,16 @@ import gradio as gr import os -from .common_gui import get_folder_path, scriptdir, list_dirs, list_files, create_refresh_button +from .common_gui import get_folder_path, scriptdir, list_dirs, create_refresh_button + class Folders: """ A class to handle folder operations in the GUI. """ - def __init__(self, finetune: bool = False, headless: bool = False, config:dict = {}): + + def __init__( + self, finetune: bool = False, headless: bool = False, config: dict = {} + ): """ Initialize the Folders class. @@ -21,9 +25,15 @@ def __init__(self, finetune: bool = False, headless: bool = False, config:dict = self.config = config # Set default directories if not provided - self.current_output_dir = self.config.get('output_dir', os.path.join(scriptdir, "outputs")) - self.current_logging_dir = self.config.get('logging_dir', os.path.join(scriptdir, "logs")) - self.current_reg_data_dir = self.config.get('reg_data_dir', os.path.join(scriptdir, "reg")) + self.current_output_dir = self.config.get( + "output_dir", os.path.join(scriptdir, "outputs") + ) + self.current_logging_dir = self.config.get( + "logging_dir", os.path.join(scriptdir, "logs") + ) + self.current_reg_data_dir = self.config.get( + "reg_data_dir", os.path.join(scriptdir, "reg") + ) # Create directories if they don't exist self.create_directory_if_not_exists(self.current_output_dir) @@ -39,10 +49,13 @@ def create_directory_if_not_exists(self, directory: str) -> None: Parameters: - directory (str): The directory to create. 
""" - if directory is not None and directory.strip() != "" and not os.path.exists(directory): + if ( + directory is not None + and directory.strip() != "" + and not os.path.exists(directory) + ): os.makedirs(directory, exist_ok=True) - def list_output_dirs(self, path: str) -> list: """ List directories in the output directory. @@ -90,16 +103,26 @@ def create_folders_gui(self) -> None: # Output directory dropdown self.output_dir = gr.Dropdown( label="Output directory for trained model", - choices=[""] + self.list_output_dirs(self.current_output_dir), - value="", + choices=[self.config.get("folders.output_dir", "")] + self.list_output_dirs(self.current_output_dir), + value=self.config.get("folders.output_dir", ""), interactive=True, allow_custom_value=True, ) # Refresh button for output directory - create_refresh_button(self.output_dir, lambda: None, lambda: {"choices": [""] + self.list_output_dirs(self.current_output_dir)}, "open_folder_small") + create_refresh_button( + self.output_dir, + lambda: None, + lambda: { + "choices": [""] + self.list_output_dirs(self.current_output_dir) + }, + "open_folder_small", + ) # Output directory button self.output_dir_folder = gr.Button( - '📂', elem_id='open_folder_small', elem_classes=["tool"], visible=(not self.headless) + "📂", + elem_id="open_folder_small", + elem_classes=["tool"], + visible=(not self.headless), ) # Output directory button click event self.output_dir_folder.click( @@ -110,17 +133,31 @@ def create_folders_gui(self) -> None: # Regularisation directory dropdown self.reg_data_dir = gr.Dropdown( - label='Regularisation directory (Optional. containing regularisation images)' if not self.finetune else 'Train config directory (Optional. where config files will be saved)', - choices=[""] + self.list_reg_data_dirs(self.current_reg_data_dir), - value="", + label=( + "Regularisation directory (Optional. containing regularisation images)" + if not self.finetune + else "Train config directory (Optional. where config files will be saved)" + ), + choices=[self.config.get("folders.reg_data_dir", "")] + self.list_reg_data_dirs(self.current_reg_data_dir), + value=self.config.get("folders.reg_data_dir", ""), interactive=True, allow_custom_value=True, ) # Refresh button for regularisation directory - create_refresh_button(self.reg_data_dir, lambda: None, lambda: {"choices": [""] + self.list_reg_data_dirs(self.current_reg_data_dir)}, "open_folder_small") + create_refresh_button( + self.reg_data_dir, + lambda: None, + lambda: { + "choices": [""] + self.list_reg_data_dirs(self.current_reg_data_dir) + }, + "open_folder_small", + ) # Regularisation directory button self.reg_data_dir_folder = gr.Button( - '📂', elem_id='open_folder_small', elem_classes=["tool"], visible=(not self.headless) + "📂", + elem_id="open_folder_small", + elem_classes=["tool"], + visible=(not self.headless), ) # Regularisation directory button click event self.reg_data_dir_folder.click( @@ -131,17 +168,27 @@ def create_folders_gui(self) -> None: with gr.Row(): # Logging directory dropdown self.logging_dir = gr.Dropdown( - label='Logging directory (Optional. to enable logging and output Tensorboard log)', - choices=[""] + self.list_logging_dirs(self.current_logging_dir), - value="", + label="Logging directory (Optional. 
to enable logging and output Tensorboard log)", + choices=[self.config.get("folders.logging_dir", "")] + self.list_logging_dirs(self.current_logging_dir), + value=self.config.get("folders.logging_dir", ""), interactive=True, allow_custom_value=True, ) # Refresh button for logging directory - create_refresh_button(self.logging_dir, lambda: None, lambda: {"choices": [""] + self.list_logging_dirs(self.current_logging_dir)}, "open_folder_small") + create_refresh_button( + self.logging_dir, + lambda: None, + lambda: { + "choices": [""] + self.list_logging_dirs(self.current_logging_dir) + }, + "open_folder_small", + ) # Logging directory button self.logging_dir_folder = gr.Button( - '📂', elem_id='open_folder_small', elem_classes=["tool"], visible=(not self.headless) + "📂", + elem_id="open_folder_small", + elem_classes=["tool"], + visible=(not self.headless), ) # Logging directory button click event self.logging_dir_folder.click( @@ -159,14 +206,18 @@ def create_folders_gui(self) -> None: ) # Change event for regularisation directory dropdown self.reg_data_dir.change( - fn=lambda path: gr.Dropdown(choices=[""] + self.list_reg_data_dirs(path)), + fn=lambda path: gr.Dropdown( + choices=[""] + self.list_reg_data_dirs(path) + ), inputs=self.reg_data_dir, outputs=self.reg_data_dir, show_progress=False, ) # Change event for logging directory dropdown self.logging_dir.change( - fn=lambda path: gr.Dropdown(choices=[""] + self.list_logging_dirs(path)), + fn=lambda path: gr.Dropdown( + choices=[""] + self.list_logging_dirs(path) + ), inputs=self.logging_dir, outputs=self.logging_dir, show_progress=False, diff --git a/kohya_gui/class_gui_config.py b/kohya_gui/class_gui_config.py index 7f73ac3f7..3624631e6 100644 --- a/kohya_gui/class_gui_config.py +++ b/kohya_gui/class_gui_config.py @@ -5,18 +5,19 @@ # Set up logging log = setup_logging() + class KohyaSSGUIConfig: """ A class to handle the configuration for the Kohya SS GUI. """ - def __init__(self): + def __init__(self, config_file_path: str = "./config.toml"): """ Initialize the KohyaSSGUIConfig class. """ - self.config = self.load_config() + self.config = self.load_config(config_file_path=config_file_path) - def load_config(self) -> dict: + def load_config(self, config_file_path: str = "./config.toml") -> dict: """ Loads the Kohya SS GUI configuration from a TOML file. @@ -25,16 +26,18 @@ def load_config(self) -> dict: """ try: # Attempt to load the TOML configuration file from the specified directory. - config = toml.load(f"{scriptdir}/config.toml") - log.debug(f"Loaded configuration from {scriptdir}/config.toml") + config = toml.load(f"{config_file_path}") + log.debug(f"Loaded configuration from {config_file_path}") except FileNotFoundError: # If the config file is not found, initialize `config` as an empty dictionary to handle missing configurations gracefully. config = {} - log.debug(f"No configuration file found at {scriptdir}/config.toml. Initializing empty configuration.") + log.debug( + f"No configuration file found at {config_file_path}. Initializing empty configuration." + ) return config - def save_config(self, config: dict): + def save_config(self, config: dict, config_file_path: str = "./config.toml"): """ Saves the Kohya SS GUI configuration to a TOML file. @@ -42,7 +45,7 @@ def save_config(self, config: dict): - config (dict): The configuration data to save. 
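
The dotted keys used throughout this change set (e.g. "advanced.clip_skip", "basic.learning_rate") correspond to nested TOML tables; a sketch of the config.toml shape that get() walks, with illustrative contents:

    import toml

    # Hypothetical ./config.toml contents:
    example = """
    [basic]
    learning_rate = 1e-4

    [advanced]
    clip_skip = 2
    """

    config = toml.loads(example)
    # get("advanced.clip_skip", 1) splits the key on "." and walks the nested dicts:
    value = config
    for part in "advanced.clip_skip".split("."):
        value = value[part]
    print(value)  # -> 2
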
""" # Write the configuration data to the TOML file - with open(f"{scriptdir}/config.toml", "w") as f: + with open(f"{config_file_path}", "w") as f: toml.dump(config, f) def get(self, key: str, default=None): @@ -66,7 +69,9 @@ def get(self, key: str, default=None): log.debug(k) # If the key is not found in the current data, return the default value if k not in data: - log.debug(f"Key '{key}' not found in configuration. Returning default value.") + log.debug( + f"Key '{key}' not found in configuration. Returning default value." + ) return default # Update `data` to the value associated with the current key diff --git a/kohya_gui/class_lora_tab.py b/kohya_gui/class_lora_tab.py index b8db73d42..efeaf952e 100644 --- a/kohya_gui/class_lora_tab.py +++ b/kohya_gui/class_lora_tab.py @@ -14,9 +14,7 @@ class LoRATools: def __init__(self, headless: bool = False): self.headless = headless - gr.Markdown( - 'This section provide various LoRA tools...' - ) + gr.Markdown("This section provide various LoRA tools...") gradio_extract_dylora_tab(headless=headless) gradio_convert_lcm_tab(headless=headless) gradio_extract_lora_tab(headless=headless) diff --git a/kohya_gui/class_sample_images.py b/kohya_gui/class_sample_images.py index 4f22c6320..7a1fa0c2d 100644 --- a/kohya_gui/class_sample_images.py +++ b/kohya_gui/class_sample_images.py @@ -1,16 +1,16 @@ import os import gradio as gr -from easygui import msgbox from .custom_logging import setup_logging +from .class_gui_config import KohyaSSGUIConfig # Set up logging log = setup_logging() -folder_symbol = '\U0001f4c2' # 📂 -refresh_symbol = '\U0001f504' # 🔄 -save_style_symbol = '\U0001f4be' # 💾 -document_symbol = '\U0001F4C4' # 📄 +folder_symbol = "\U0001f4c2" # 📂 +refresh_symbol = "\U0001f504" # 🔄 +save_style_symbol = "\U0001f4be" # 💾 +document_symbol = "\U0001F4C4" # 📄 ### @@ -38,10 +38,10 @@ def run_cmd_sample( Returns: str: The command string for sampling images. """ - output_dir = os.path.join(output_dir, 'sample') + output_dir = os.path.join(output_dir, "sample") os.makedirs(output_dir, exist_ok=True) - run_cmd = '' + run_cmd = "" if sample_every_n_epochs is None: sample_every_n_epochs = 0 @@ -53,19 +53,19 @@ def run_cmd_sample( return run_cmd # Create the prompt file and get its path - sample_prompts_path = os.path.join(output_dir, 'prompt.txt') + sample_prompts_path = os.path.join(output_dir, "prompt.txt") - with open(sample_prompts_path, 'w') as f: + with open(sample_prompts_path, "w") as f: f.write(sample_prompts) - run_cmd += f' --sample_sampler={sample_sampler}' + run_cmd += f" --sample_sampler={sample_sampler}" run_cmd += f' --sample_prompts="{sample_prompts_path}"' if sample_every_n_epochs != 0: - run_cmd += f' --sample_every_n_epochs={sample_every_n_epochs}' + run_cmd += f" --sample_every_n_epochs={sample_every_n_epochs}" if sample_every_n_steps != 0: - run_cmd += f' --sample_every_n_steps={sample_every_n_steps}' + run_cmd += f" --sample_every_n_steps={sample_every_n_steps}" return run_cmd @@ -77,10 +77,13 @@ class SampleImages: def __init__( self, + config: KohyaSSGUIConfig = {}, ): """ Initializes the SampleImages class. 
""" + self.config = config + self.initialize_accordion() def initialize_accordion(self): @@ -89,45 +92,46 @@ def initialize_accordion(self): """ with gr.Row(): self.sample_every_n_steps = gr.Number( - label='Sample every n steps', - value=0, + label="Sample every n steps", + value=self.config.get("samples.sample_every_n_steps", 0), precision=0, interactive=True, ) self.sample_every_n_epochs = gr.Number( - label='Sample every n epochs', - value=0, + label="Sample every n epochs", + value=self.config.get("samples.sample_every_n_epochs", 0), precision=0, interactive=True, ) self.sample_sampler = gr.Dropdown( - label='Sample sampler', + label="Sample sampler", choices=[ - 'ddim', - 'pndm', - 'lms', - 'euler', - 'euler_a', - 'heun', - 'dpm_2', - 'dpm_2_a', - 'dpmsolver', - 'dpmsolver++', - 'dpmsingle', - 'k_lms', - 'k_euler', - 'k_euler_a', - 'k_dpm_2', - 'k_dpm_2_a', + "ddim", + "pndm", + "lms", + "euler", + "euler_a", + "heun", + "dpm_2", + "dpm_2_a", + "dpmsolver", + "dpmsolver++", + "dpmsingle", + "k_lms", + "k_euler", + "k_euler_a", + "k_dpm_2", + "k_dpm_2_a", ], - value='euler_a', + value=self.config.get("samples.sample_sampler", "euler_a"), interactive=True, ) with gr.Row(): self.sample_prompts = gr.Textbox( lines=5, - label='Sample prompts', + label="Sample prompts", interactive=True, - placeholder='masterpiece, best quality, 1girl, in white shirts, upper body, looking at viewer, simple background --n low quality, worst quality, bad anatomy,bad composition, poor, low effort --w 768 --h 768 --d 1 --l 7.5 --s 28', - info='Enter one sample prompt per line to generate multiple samples per cycle. Optional specifiers include: --w (width), --h (height), --d (seed), --l (cfg scale), --s (sampler steps) and --n (negative prompt). To modify sample prompts during training, edit the prompt.txt file in the samples directory.', + placeholder="masterpiece, best quality, 1girl, in white shirts, upper body, looking at viewer, simple background --n low quality, worst quality, bad anatomy,bad composition, poor, low effort --w 768 --h 768 --d 1 --l 7.5 --s 28", + info="Enter one sample prompt per line to generate multiple samples per cycle. Optional specifiers include: --w (width), --h (height), --d (seed), --l (cfg scale), --s (sampler steps) and --n (negative prompt). To modify sample prompts during training, edit the prompt.txt file in the samples directory.", + value=self.config.get("samples.sample_prompts", ""), ) diff --git a/kohya_gui/class_sdxl_parameters.py b/kohya_gui/class_sdxl_parameters.py index a95ff13af..b0098d2a3 100644 --- a/kohya_gui/class_sdxl_parameters.py +++ b/kohya_gui/class_sdxl_parameters.py @@ -1,32 +1,34 @@ import gradio as gr +from .class_gui_config import KohyaSSGUIConfig class SDXLParameters: def __init__( self, sdxl_checkbox: gr.Checkbox, show_sdxl_cache_text_encoder_outputs: bool = True, + config: KohyaSSGUIConfig = {}, ): self.sdxl_checkbox = sdxl_checkbox - self.show_sdxl_cache_text_encoder_outputs = ( - show_sdxl_cache_text_encoder_outputs - ) + self.show_sdxl_cache_text_encoder_outputs = show_sdxl_cache_text_encoder_outputs + self.config = config + self.initialize_accordion() def initialize_accordion(self): with gr.Accordion( - visible=False, open=True, label='SDXL Specific Parameters' + visible=False, open=True, label="SDXL Specific Parameters" ) as self.sdxl_row: with gr.Row(): self.sdxl_cache_text_encoder_outputs = gr.Checkbox( - label='Cache text encoder outputs', - info='Cache the outputs of the text encoders. This option is useful to reduce the GPU memory usage. 
This option cannot be used with options for shuffling or dropping the captions.', - value=False, + label="Cache text encoder outputs", + info="Cache the outputs of the text encoders. This option is useful to reduce the GPU memory usage. This option cannot be used with options for shuffling or dropping the captions.", + value=self.config.get("sdxl.sdxl_cache_text_encoder_outputs", False), visible=self.show_sdxl_cache_text_encoder_outputs, ) self.sdxl_no_half_vae = gr.Checkbox( - label='No half VAE', - info='Disable the half-precision (mixed-precision) VAE. VAE for SDXL seems to produce NaNs in some cases. This option is useful to avoid the NaNs.', - value=True, + label="No half VAE", + info="Disable the half-precision (mixed-precision) VAE. VAE for SDXL seems to produce NaNs in some cases. This option is useful to avoid the NaNs.", + value=self.config.get("sdxl.sdxl_no_half_vae", False), ) self.sdxl_checkbox.change( diff --git a/kohya_gui/class_source_model.py b/kohya_gui/class_source_model.py index 0c92905b9..a390e42fd 100644 --- a/kohya_gui/class_source_model.py +++ b/kohya_gui/class_source_model.py @@ -10,6 +10,7 @@ list_files, create_refresh_button, ) +from .class_gui_config import KohyaSSGUIConfig folder_symbol = "\U0001f4c2" # 📂 refresh_symbol = "\U0001f504" # 🔄 @@ -47,7 +48,7 @@ def __init__( ], headless=False, finetuning=False, - config: dict = {}, + config: KohyaSSGUIConfig = {}, ): self.headless = headless self.save_model_as_choices = save_model_as_choices @@ -56,13 +57,14 @@ def __init__( # Set default directories if not provided self.current_models_dir = self.config.get( - "models_dir", os.path.join(scriptdir, "models") + "model.models_dir", os.path.join(scriptdir, "models") ) self.current_train_data_dir = self.config.get( - "train_data_dir", os.path.join(scriptdir, "data") + "model.train_data_dir", os.path.join(scriptdir, "data") + ) + self.current_dataset_config_dir = self.config.get( + "model.dataset_config", os.path.join(scriptdir, "dataset_config") ) - self.current_dataset_config_dir = self.config.get('dataset_config_dir', os.path.join(scriptdir, "dataset_config")) - model_checkpoints = list( list_files( @@ -80,8 +82,8 @@ def list_models(path): def list_train_data_dirs(path): self.current_train_data_dir = path if not path == "" else "." - return list(list_dirs(path)) - + return list(list_dirs(self.current_train_data_dir)) + def list_dataset_config_dirs(path: str) -> list: """ List directories and toml files in the dataset_config directory. @@ -94,185 +96,210 @@ def list_dataset_config_dirs(path: str) -> list: """ current_dataset_config_dir = path if not path == "" else "." # Lists all .json files in the current configuration directory, used for populating dropdown choices. 
- return list(list_files(current_dataset_config_dir, exts=[".toml"], all=True)) + return list( + list_files(current_dataset_config_dir, exts=[".toml"], all=True) + ) + with gr.Accordion("Model", open=True): + with gr.Column(), gr.Group(): + # Define the input elements + with gr.Row(): + with gr.Column(), gr.Row(): + self.model_list = gr.Textbox(visible=False, value="") + self.pretrained_model_name_or_path = gr.Dropdown( + label="Pretrained model name or path", + choices=default_models + model_checkpoints, + value=self.config.get("model.models_dir", "runwayml/stable-diffusion-v1-5"), + allow_custom_value=True, + visible=True, + min_width=100, + ) + create_refresh_button( + self.pretrained_model_name_or_path, + lambda: None, + lambda: {"choices": list_models(self.current_models_dir)}, + "open_folder_small", + ) - with gr.Column(), gr.Group(): - # Define the input elements - with gr.Row(): - with gr.Column(), gr.Row(): - self.model_list = gr.Textbox(visible=False, value="") - self.pretrained_model_name_or_path = gr.Dropdown( - label="Pretrained model name or path", - choices=default_models + model_checkpoints, - value="runwayml/stable-diffusion-v1-5", - allow_custom_value=True, - visible=True, - min_width=100, - ) - create_refresh_button( - self.pretrained_model_name_or_path, - lambda: None, - lambda: {"choices": list_models(self.current_models_dir)}, - "open_folder_small", - ) + self.pretrained_model_name_or_path_file = gr.Button( + document_symbol, + elem_id="open_folder_small", + elem_classes=["tool"], + visible=(not headless), + ) + self.pretrained_model_name_or_path_file.click( + get_file_path, + inputs=self.pretrained_model_name_or_path, + outputs=self.pretrained_model_name_or_path, + show_progress=False, + ) + self.pretrained_model_name_or_path_folder = gr.Button( + folder_symbol, + elem_id="open_folder_small", + elem_classes=["tool"], + visible=(not headless), + ) + self.pretrained_model_name_or_path_folder.click( + get_folder_path, + inputs=self.pretrained_model_name_or_path, + outputs=self.pretrained_model_name_or_path, + show_progress=False, + ) - self.pretrained_model_name_or_path_file = gr.Button( - document_symbol, - elem_id="open_folder_small", - elem_classes=["tool"], - visible=(not headless), - ) - self.pretrained_model_name_or_path_file.click( - get_file_path, - inputs=self.pretrained_model_name_or_path, - outputs=self.pretrained_model_name_or_path, - show_progress=False, - ) - self.pretrained_model_name_or_path_folder = gr.Button( - folder_symbol, - elem_id="open_folder_small", - elem_classes=["tool"], - visible=(not headless), - ) - self.pretrained_model_name_or_path_folder.click( - get_folder_path, - inputs=self.pretrained_model_name_or_path, - outputs=self.pretrained_model_name_or_path, - show_progress=False, - ) - - with gr.Column(), gr.Row(): - self.output_name = gr.Textbox( - label="Trained Model output name", - placeholder="(Name of the model to output)", - value="last", - interactive=True, - ) - with gr.Row(): - with gr.Column(), gr.Row(): - self.train_data_dir = gr.Dropdown( - label=( - "Image folder (containing training images subfolders)" - if not finetuning - else "Image folder (containing training images)" - ), - choices=[""] - + list_train_data_dirs(self.current_train_data_dir), - value="", - interactive=True, - allow_custom_value=True, - ) - create_refresh_button( - self.train_data_dir, - lambda: None, - lambda: { - "choices": [""] - + list_train_data_dirs(self.current_train_data_dir) - }, - "open_folder_small", - ) - self.train_data_dir_folder = gr.Button( - 
"📂", - elem_id="open_folder_small", - elem_classes=["tool"], - visible=(not self.headless), - ) - self.train_data_dir_folder.click( - get_folder_path, - outputs=self.train_data_dir, - show_progress=False, - ) - with gr.Column(), gr.Row(): - # Toml directory dropdown - self.dataset_config = gr.Dropdown( - label='Dataset config file (Optional. Select the toml configuration file to use for the dataset)', - choices=[""] + list_dataset_config_dirs(self.current_dataset_config_dir), - value="", + with gr.Column(), gr.Row(): + self.output_name = gr.Textbox( + label="Trained Model output name", + placeholder="(Name of the model to output)", + value=self.config.get("model.output_name", "last"), + interactive=True, + ) + with gr.Row(): + with gr.Column(), gr.Row(): + self.train_data_dir = gr.Dropdown( + label=( + "Image folder (containing training images subfolders)" + if not finetuning + else "Image folder (containing training images)" + ), + choices=[""] + + list_train_data_dirs(self.current_train_data_dir), + value=self.config.get("model.train_data_dir", ""), + interactive=True, + allow_custom_value=True, + ) + create_refresh_button( + self.train_data_dir, + lambda: None, + lambda: { + "choices": [""] + + list_train_data_dirs(self.current_train_data_dir) + }, + "open_folder_small", + ) + self.train_data_dir_folder = gr.Button( + "📂", + elem_id="open_folder_small", + elem_classes=["tool"], + visible=(not self.headless), + ) + self.train_data_dir_folder.click( + get_folder_path, + outputs=self.train_data_dir, + show_progress=False, + ) + with gr.Column(), gr.Row(): + # Toml directory dropdown + self.dataset_config = gr.Dropdown( + label="Dataset config file (Optional. Select the toml configuration file to use for the dataset)", + choices=[self.config.get("model.dataset_config", "")] + + list_dataset_config_dirs(self.current_dataset_config_dir), + value=self.config.get("model.dataset_config", ""), + interactive=True, + allow_custom_value=True, + ) + # Refresh button for dataset_config directory + create_refresh_button( + self.dataset_config, + lambda: None, + lambda: { + "choices": [""] + + list_dataset_config_dirs( + self.current_dataset_config_dir + ) + }, + "open_folder_small", + ) + # Toml directory button + self.dataset_config_folder = gr.Button( + document_symbol, + elem_id="open_folder_small", + elem_classes=["tool"], + visible=(not self.headless), + ) + + # Toml directory button click event + self.dataset_config_folder.click( + get_file_path, + inputs=[ + self.dataset_config, + gr.Textbox(value="*.toml", visible=False), + gr.Textbox(value="Dataset config types", visible=False), + ], + outputs=self.dataset_config, + show_progress=False, + ) + # Change event for dataset_config directory dropdown + self.dataset_config.change( + fn=lambda path: gr.Dropdown( + choices=[""] + list_dataset_config_dirs(path) + ), + inputs=self.dataset_config, + outputs=self.dataset_config, + show_progress=False, + ) + + with gr.Row(): + with gr.Column(): + with gr.Row(): + self.v2 = gr.Checkbox( + label="v2", value=False, visible=False, min_width=60 + ) + self.v_parameterization = gr.Checkbox( + label="v_parameterization", + value=False, + visible=False, + min_width=130, + ) + self.sdxl_checkbox = gr.Checkbox( + label="SDXL", + value=False, + visible=False, + min_width=60, + ) + with gr.Column(): + gr.Box(visible=False) + + with gr.Row(): + self.training_comment = gr.Textbox( + label="Training comment", + placeholder="(Optional) Add training comment to be included in metadata", interactive=True, - 
allow_custom_value=True, - ) - # Refresh button for dataset_config directory - create_refresh_button(self.dataset_config, lambda: None, lambda: {"choices": [""] + list_dataset_config_dirs(self.current_dataset_config_dir)}, "open_folder_small") - # Toml directory button - self.dataset_config_folder = gr.Button( - document_symbol, elem_id='open_folder_small', elem_classes=["tool"], visible=(not self.headless) + value=self.config.get("model.training_comment", ""), ) - - # Toml directory button click event - self.dataset_config_folder.click( - get_file_path, - inputs=[self.dataset_config, gr.Textbox(value='*.toml', visible=False), gr.Textbox(value='Dataset config types', visible=False)], - outputs=self.dataset_config, - show_progress=False, + + with gr.Row(): + self.save_model_as = gr.Radio( + save_model_as_choices, + label="Save trained model as", + value=self.config.get("model.save_model_as", "safetensors"), ) - # Change event for dataset_config directory dropdown - self.dataset_config.change( - fn=lambda path: gr.Dropdown(choices=[""] + list_dataset_config_dirs(path)), - inputs=self.dataset_config, - outputs=self.dataset_config, - show_progress=False, + self.save_precision = gr.Radio( + save_precision_choices, + label="Save precision", + value=self.config.get("model.save_precision", "fp16"), ) - with gr.Row(): - with gr.Column(): - with gr.Row(): - self.v2 = gr.Checkbox( - label="v2", value=False, visible=False, min_width=60 - ) - self.v_parameterization = gr.Checkbox( - label="v_parameterization", - value=False, - visible=False, - min_width=130, - ) - self.sdxl_checkbox = gr.Checkbox( - label="SDXL", - value=False, - visible=False, - min_width=60, - ) - with gr.Column(): - gr.Box(visible=False) - - with gr.Row(): - self.training_comment = gr.Textbox( - label="Training comment", - placeholder="(Optional) Add training comment to be included in metadata", - interactive=True, + self.pretrained_model_name_or_path.change( + fn=lambda path: set_pretrained_model_name_or_path_input( + path, refresh_method=list_models + ), + inputs=[ + self.pretrained_model_name_or_path, + ], + outputs=[ + self.pretrained_model_name_or_path, + self.v2, + self.v_parameterization, + self.sdxl_checkbox, + ], + show_progress=False, ) - with gr.Row(): - self.save_model_as = gr.Radio( - save_model_as_choices, - label="Save trained model as", - value="safetensors", - ) - self.save_precision = gr.Radio( - save_precision_choices, - label="Save precision", - value="fp16", + self.train_data_dir.change( + fn=lambda path: gr.Dropdown( + choices=[""] + list_train_data_dirs(path) + ), + inputs=self.train_data_dir, + outputs=self.train_data_dir, + show_progress=False, ) - - self.pretrained_model_name_or_path.change( - fn=lambda path: set_pretrained_model_name_or_path_input( - path, refresh_method=list_models - ), - inputs=[ - self.pretrained_model_name_or_path, - ], - outputs=[ - self.pretrained_model_name_or_path, - self.v2, - self.v_parameterization, - self.sdxl_checkbox, - ], - show_progress=False, - ) - - self.train_data_dir.change( - fn=lambda path: gr.Dropdown(choices=[""] + list_train_data_dirs(path)), - inputs=self.train_data_dir, - outputs=self.train_data_dir, - show_progress=False, - ) diff --git a/kohya_gui/common_gui.py b/kohya_gui/common_gui.py index b2dee8c9b..6c922eb37 100644 --- a/kohya_gui/common_gui.py +++ b/kohya_gui/common_gui.py @@ -6,9 +6,9 @@ import os import re import gradio as gr -import shutil import sys import json +import math # Set up logging log = setup_logging() @@ -54,6 +54,7 @@ ENV_EXCLUSION = 
["COLAB_GPU", "RUNPOD_POD_ID"] + def calculate_max_train_steps( total_steps: int, train_batch_size: int, @@ -71,6 +72,7 @@ def calculate_max_train_steps( ) ) + def check_if_model_exist( output_name: str, output_dir: str, save_model_as: str, headless: bool = False ) -> bool: @@ -788,60 +790,6 @@ def color_aug_changed(color_aug): return gr.Checkbox(interactive=True) -def save_inference_file( - output_dir: str, - v2: bool, - v_parameterization: bool, - output_name: str, -) -> None: - """ - Save inference file to the specified output directory. - - Args: - output_dir (str): Path to the output directory. - v2 (bool): Flag indicating whether to use v2 inference. - v_parameterization (bool): Flag indicating whether to use v parameterization. - output_name (str): Name of the output file. - """ - try: - # List all files in the directory - files = os.listdir(output_dir) - except Exception as e: - log.error(f"Error listing directory contents: {e}") - return # Early return on failure - - # Iterate over the list of files - for file in files: - # Check if the file starts with the value of output_name - if file.startswith(output_name): - # Check if it is a file or a directory - file_path = os.path.join(output_dir, file) - if os.path.isfile(file_path): - # Split the file name and extension - file_name, ext = os.path.splitext(file) - - # Determine the source file path based on the v2 and v_parameterization flags - source_file_path = ( - rf"{scriptdir}/v2_inference/v2-inference-v.yaml" - if v2 and v_parameterization - else rf"{scriptdir}/v2_inference/v2-inference.yaml" - ) - - # Copy the source file to the current file, with a .yaml extension - try: - log.info( - f"Saving {source_file_path} as {output_dir}/{file_name}.yaml" - ) - shutil.copy( - source_file_path, - f"{output_dir}/{file_name}.yaml", - ) - except Exception as e: - log.error( - f"Error copying file to {output_dir}/{file_name}.yaml: {e}" - ) - - def set_pretrained_model_name_or_path_input( pretrained_model_name_or_path, refresh_method=None ): @@ -1151,6 +1099,14 @@ def run_cmd_advanced_training(**kwargs): if kwargs.get("gradient_checkpointing"): run_cmd += " --gradient_checkpointing" + if kwargs.get("ip_noise_gamma"): + if float(kwargs["ip_noise_gamma"]) > 0: + run_cmd += f' --ip_noise_gamma={kwargs["ip_noise_gamma"]}' + + if kwargs.get("ip_noise_gamma_random_strength"): + if kwargs["ip_noise_gamma_random_strength"]: + run_cmd += f" --ip_noise_gamma_random_strength" + if "keep_tokens" in kwargs and int(kwargs["keep_tokens"]) > 0: run_cmd += f' --keep_tokens="{int(kwargs["keep_tokens"])}"' @@ -1224,230 +1180,249 @@ def run_cmd_advanced_training(**kwargs): else: run_cmd += f' --lr_warmup_steps="{lr_warmup_steps}"' - gpu_ids = kwargs.get("gpu_ids") - if gpu_ids: - run_cmd += f' --gpu_ids="{gpu_ids}"' - - max_data_loader_n_workers = kwargs.get("max_data_loader_n_workers") - if max_data_loader_n_workers and not max_data_loader_n_workers == "": - run_cmd += f' --max_data_loader_n_workers="{max_data_loader_n_workers}"' - - max_grad_norm = kwargs.get("max_grad_norm") - if max_grad_norm and max_grad_norm != "": - run_cmd += f' --max_grad_norm="{max_grad_norm}"' - - max_resolution = kwargs.get("max_resolution") - if max_resolution: - run_cmd += f' --resolution="{max_resolution}"' - - max_timestep = kwargs.get("max_timestep") - if max_timestep and int(max_timestep) < 1000: - run_cmd += f" --max_timestep={int(max_timestep)}" - - max_token_length = kwargs.get("max_token_length") - if max_token_length and int(max_token_length) > 75: - run_cmd += f" 
--max_token_length={int(max_token_length)}" - - max_train_epochs = kwargs.get("max_train_epochs") - if max_train_epochs and not max_train_epochs == "": - run_cmd += f" --max_train_epochs={max_train_epochs}" - - max_train_steps = kwargs.get("max_train_steps") - if max_train_steps: - run_cmd += f' --max_train_steps="{max_train_steps}"' - - mem_eff_attn = kwargs.get("mem_eff_attn") - if mem_eff_attn: - run_cmd += " --mem_eff_attn" - - min_snr_gamma = kwargs.get("min_snr_gamma") - if min_snr_gamma and int(min_snr_gamma) >= 1: - run_cmd += f" --min_snr_gamma={int(min_snr_gamma)}" - - min_timestep = kwargs.get("min_timestep") - if min_timestep and int(min_timestep) > 0: - run_cmd += f" --min_timestep={int(min_timestep)}" - - mixed_precision = kwargs.get("mixed_precision") - if mixed_precision: - run_cmd += f' --mixed_precision="{mixed_precision}"' - - multi_gpu = kwargs.get("multi_gpu") - if multi_gpu: - run_cmd += " --multi_gpu" - - network_alpha = kwargs.get("network_alpha") - if network_alpha: - run_cmd += f' --network_alpha="{network_alpha}"' - - network_args = kwargs.get("network_args") - if network_args and len(network_args): - run_cmd += f" --network_args{network_args}" - - network_dim = kwargs.get("network_dim") - if network_dim: - run_cmd += f" --network_dim={network_dim}" - - network_dropout = kwargs.get("network_dropout") - if network_dropout and network_dropout > 0.0: - run_cmd += f" --network_dropout={network_dropout}" - - network_module = kwargs.get("network_module") - if network_module: - run_cmd += f" --network_module={network_module}" - - network_train_text_encoder_only = kwargs.get("network_train_text_encoder_only") - if network_train_text_encoder_only: - run_cmd += " --network_train_text_encoder_only" - - network_train_unet_only = kwargs.get("network_train_unet_only") - if network_train_unet_only: - run_cmd += " --network_train_unet_only" - - no_half_vae = kwargs.get("no_half_vae") - if no_half_vae: - run_cmd += " --no_half_vae" - - no_token_padding = kwargs.get("no_token_padding") - if no_token_padding: - run_cmd += " --no_token_padding" + if "masked_loss" in kwargs: + if kwargs.get("masked_loss"): # Test if the value is true as it could be false + run_cmd += " --masked_loss" + + if "max_data_loader_n_workers" in kwargs: + max_data_loader_n_workers = kwargs.get("max_data_loader_n_workers") + if not max_data_loader_n_workers == "": + run_cmd += f' --max_data_loader_n_workers="{max_data_loader_n_workers}"' + + if "max_grad_norm" in kwargs: + max_grad_norm = kwargs.get("max_grad_norm") + if max_grad_norm != "": + run_cmd += f' --max_grad_norm="{max_grad_norm}"' + + if "max_resolution" in kwargs: + run_cmd += rf' --resolution="{kwargs.get("max_resolution")}"' + + if "max_timestep" in kwargs: + max_timestep = kwargs.get("max_timestep") + if int(max_timestep) < 1000: + run_cmd += f" --max_timestep={int(max_timestep)}" + + if "max_token_length" in kwargs: + max_token_length = kwargs.get("max_token_length") + if int(max_token_length) > 75: + run_cmd += f" --max_token_length={int(max_token_length)}" + + if "max_train_epochs" in kwargs: + max_train_epochs = kwargs.get("max_train_epochs") + if not max_train_epochs == "": + run_cmd += f" --max_train_epochs={max_train_epochs}" + + if "max_train_steps" in kwargs: + max_train_steps = kwargs.get("max_train_steps") + if not max_train_steps == "": + run_cmd += f' --max_train_steps="{max_train_steps}"' + + if "mem_eff_attn" in kwargs: + if kwargs.get("mem_eff_attn"): # Test if the value is true as it could be false + run_cmd += " 
--mem_eff_attn" + + if "min_snr_gamma" in kwargs: + min_snr_gamma = kwargs.get("min_snr_gamma") + if int(min_snr_gamma) >= 1: + run_cmd += f" --min_snr_gamma={int(min_snr_gamma)}" + + if "min_timestep" in kwargs: + min_timestep = kwargs.get("min_timestep") + if int(min_timestep) > -1: + run_cmd += f" --min_timestep={int(min_timestep)}" + + if "mixed_precision" in kwargs: + run_cmd += rf' --mixed_precision="{kwargs.get("mixed_precision")}"' + + if "network_alpha" in kwargs: + run_cmd += rf' --network_alpha="{kwargs.get("network_alpha")}"' + + if "network_args" in kwargs: + network_args = kwargs.get("network_args") + if network_args != "": + run_cmd += f" --network_args{network_args}" + + if "network_dim" in kwargs: + run_cmd += rf' --network_dim={kwargs.get("network_dim")}' + + if "network_dropout" in kwargs: + network_dropout = kwargs.get("network_dropout") + if network_dropout > 0.0: + run_cmd += f" --network_dropout={network_dropout}" + + if "network_module" in kwargs: + network_module = kwargs.get("network_module") + if network_module != "": + run_cmd += f" --network_module={network_module}" + + if "network_train_text_encoder_only" in kwargs: + if kwargs.get("network_train_text_encoder_only"): + run_cmd += " --network_train_text_encoder_only" + + if "network_train_unet_only" in kwargs: + if kwargs.get("network_train_unet_only"): + run_cmd += " --network_train_unet_only" + + if "no_half_vae" in kwargs: + if kwargs.get("no_half_vae"): # Test if the value is true as it could be false + run_cmd += " --no_half_vae" + + if "no_token_padding" in kwargs: + if kwargs.get( + "no_token_padding" + ): # Test if the value is true as it could be false + run_cmd += " --no_token_padding" if "noise_offset_type" in kwargs: noise_offset_type = kwargs["noise_offset_type"] - if kwargs["noise_offset_type"] == "Original": - noise_offset = float(kwargs.get("noise_offset", 0)) - if noise_offset: - run_cmd += f" --noise_offset={noise_offset}" + if noise_offset_type == "Original": + if "noise_offset" in kwargs: + noise_offset = float(kwargs.get("noise_offset", 0)) + if noise_offset: + run_cmd += f" --noise_offset={float(noise_offset)}" - adaptive_noise_scale = float(kwargs.get("adaptive_noise_scale", 0)) - if adaptive_noise_scale != 0 and noise_offset > 0: - run_cmd += f" --adaptive_noise_scale={adaptive_noise_scale}" + if "adaptive_noise_scale" in kwargs: + adaptive_noise_scale = float(kwargs.get("adaptive_noise_scale", 0)) + if adaptive_noise_scale != 0 and noise_offset > 0: + run_cmd += f" --adaptive_noise_scale={adaptive_noise_scale}" + if "noise_offset_random_strength" in kwargs: + if kwargs.get("noise_offset_random_strength"): + run_cmd += f" --noise_offset_random_strength" elif noise_offset_type == "Multires": - multires_noise_iterations = int(kwargs.get("multires_noise_iterations", 0)) - if multires_noise_iterations > 0: - run_cmd += f' --multires_noise_iterations="{multires_noise_iterations}"' - - multires_noise_discount = float(kwargs.get("multires_noise_discount", 0)) - if multires_noise_discount > 0: - run_cmd += f' --multires_noise_discount="{multires_noise_discount}"' - - num_machines = kwargs.get("num_machines") - if num_machines and int(num_machines) > 1: - run_cmd += f" --num_machines={int(num_machines)}" - - num_processes = kwargs.get("num_processes") - if num_processes and int(num_processes) > 1: - run_cmd += f" --num_processes={int(num_processes)}" + if "multires_noise_iterations" in kwargs: + multires_noise_iterations = int( + kwargs.get("multires_noise_iterations", 0) + ) + if 
multires_noise_iterations > 0: + run_cmd += ( + f' --multires_noise_iterations="{multires_noise_iterations}"' + ) - num_cpu_threads_per_process = kwargs.get("num_cpu_threads_per_process") - if num_cpu_threads_per_process and int(num_cpu_threads_per_process) > 1: - run_cmd += f" --num_cpu_threads_per_process={int(num_cpu_threads_per_process)}" + if "multires_noise_discount" in kwargs: + multires_noise_discount = float( + kwargs.get("multires_noise_discount", 0) + ) + if multires_noise_discount > 0: + run_cmd += f' --multires_noise_discount="{multires_noise_discount}"' - optimizer_args = kwargs.get("optimizer_args") - if optimizer_args and optimizer_args != "": - run_cmd += f" --optimizer_args {optimizer_args}" + if "optimizer_args" in kwargs: + optimizer_args = kwargs.get("optimizer_args") + if optimizer_args != "": + run_cmd += f" --optimizer_args {optimizer_args}" - optimizer_type = kwargs.get("optimizer") - if optimizer_type: - run_cmd += f' --optimizer_type="{optimizer_type}"' + if "optimizer" in kwargs: + run_cmd += rf' --optimizer_type="{kwargs.get("optimizer")}"' - output_dir = kwargs.get("output_dir") - if output_dir: + if "output_dir" in kwargs: + output_dir = kwargs.get("output_dir") if output_dir.startswith('"') and output_dir.endswith('"'): output_dir = output_dir[1:-1] if os.path.exists(output_dir): run_cmd += rf' --output_dir="{output_dir}"' - output_name = kwargs.get("output_name") - if output_name and not output_name == "": - run_cmd += f' --output_name="{output_name}"' - - persistent_data_loader_workers = kwargs.get("persistent_data_loader_workers") - if persistent_data_loader_workers: - run_cmd += " --persistent_data_loader_workers" - - pretrained_model_name_or_path = kwargs.get("pretrained_model_name_or_path") - if pretrained_model_name_or_path: - run_cmd += ( - rf' --pretrained_model_name_or_path="{pretrained_model_name_or_path}"' - ) - - prior_loss_weight = kwargs.get("prior_loss_weight") - if prior_loss_weight and not float(prior_loss_weight) == 1.0: - run_cmd += f" --prior_loss_weight={prior_loss_weight}" - - random_crop = kwargs.get("random_crop") - if random_crop: - run_cmd += " --random_crop" - - reg_data_dir = kwargs.get("reg_data_dir") - if reg_data_dir and len(reg_data_dir): - if reg_data_dir.startswith('"') and reg_data_dir.endswith('"'): - reg_data_dir = reg_data_dir[1:-1] - if os.path.isdir(reg_data_dir): - run_cmd += rf' --reg_data_dir="{reg_data_dir}"' - - resume = kwargs.get("resume") - if resume: - run_cmd += f' --resume="{resume}"' - - save_every_n_epochs = kwargs.get("save_every_n_epochs") - if save_every_n_epochs: - run_cmd += f' --save_every_n_epochs="{int(save_every_n_epochs)}"' - - save_every_n_steps = kwargs.get("save_every_n_steps") - if save_every_n_steps and int(save_every_n_steps) > 0: - run_cmd += f' --save_every_n_steps="{int(save_every_n_steps)}"' - - save_last_n_steps = kwargs.get("save_last_n_steps") - if save_last_n_steps and int(save_last_n_steps) > 0: - run_cmd += f' --save_last_n_steps="{int(save_last_n_steps)}"' - - save_last_n_steps_state = kwargs.get("save_last_n_steps_state") - if save_last_n_steps_state and int(save_last_n_steps_state) > 0: - run_cmd += f' --save_last_n_steps_state="{int(save_last_n_steps_state)}"' - - save_model_as = kwargs.get("save_model_as") - if save_model_as and not save_model_as == "same as source model": - run_cmd += f" --save_model_as={save_model_as}" - - save_precision = kwargs.get("save_precision") - if save_precision: - run_cmd += f' --save_precision="{save_precision}"' - - save_state = 
kwargs.get("save_state") - if save_state: - run_cmd += " --save_state" - - scale_v_pred_loss_like_noise_pred = kwargs.get("scale_v_pred_loss_like_noise_pred") - if scale_v_pred_loss_like_noise_pred: - run_cmd += " --scale_v_pred_loss_like_noise_pred" - - scale_weight_norms = kwargs.get("scale_weight_norms") - if scale_weight_norms and scale_weight_norms > 0.0: - run_cmd += f' --scale_weight_norms="{scale_weight_norms}"' - - seed = kwargs.get("seed") - if seed and seed != "": - run_cmd += f' --seed="{seed}"' - - shuffle_caption = kwargs.get("shuffle_caption") - if shuffle_caption: - run_cmd += " --shuffle_caption" - - stop_text_encoder_training = kwargs.get("stop_text_encoder_training") - if stop_text_encoder_training and stop_text_encoder_training > 0: - run_cmd += f' --stop_text_encoder_training="{stop_text_encoder_training}"' - - text_encoder_lr = kwargs.get("text_encoder_lr") - if text_encoder_lr and (float(text_encoder_lr) > 0): - run_cmd += f" --text_encoder_lr={text_encoder_lr}" - - train_batch_size = kwargs.get("train_batch_size") - if train_batch_size: - run_cmd += f' --train_batch_size="{train_batch_size}"' + if "output_name" in kwargs: + output_name = kwargs.get("output_name") + if not output_name == "": + run_cmd += f' --output_name="{output_name}"' + + if "persistent_data_loader_workers" in kwargs: + if kwargs.get("persistent_data_loader_workers"): + run_cmd += " --persistent_data_loader_workers" + + if "pretrained_model_name_or_path" in kwargs: + run_cmd += rf' --pretrained_model_name_or_path="{kwargs.get("pretrained_model_name_or_path")}"' + + if "prior_loss_weight" in kwargs: + prior_loss_weight = kwargs.get("prior_loss_weight") + if not float(prior_loss_weight) == 1.0: + run_cmd += f" --prior_loss_weight={prior_loss_weight}" + + if "random_crop" in kwargs: + random_crop = kwargs.get("random_crop") + if random_crop: + run_cmd += " --random_crop" + + if "reg_data_dir" in kwargs: + reg_data_dir = kwargs.get("reg_data_dir") + if len(reg_data_dir): + if reg_data_dir.startswith('"') and reg_data_dir.endswith('"'): + reg_data_dir = reg_data_dir[1:-1] + if os.path.isdir(reg_data_dir): + run_cmd += rf' --reg_data_dir="{reg_data_dir}"' + + if "resume" in kwargs: + resume = kwargs.get("resume") + if len(resume): + run_cmd += f' --resume="{resume}"' + + if "save_every_n_epochs" in kwargs: + save_every_n_epochs = kwargs.get("save_every_n_epochs") + if int(save_every_n_epochs) > 0: + run_cmd += f' --save_every_n_epochs="{int(save_every_n_epochs)}"' + + if "save_every_n_steps" in kwargs: + save_every_n_steps = kwargs.get("save_every_n_steps") + if int(save_every_n_steps) > 0: + run_cmd += f' --save_every_n_steps="{int(save_every_n_steps)}"' + + if "save_last_n_steps" in kwargs: + save_last_n_steps = kwargs.get("save_last_n_steps") + if int(save_last_n_steps) > 0: + run_cmd += f' --save_last_n_steps="{int(save_last_n_steps)}"' + + if "save_last_n_steps_state" in kwargs: + save_last_n_steps_state = kwargs.get("save_last_n_steps_state") + if int(save_last_n_steps_state) > 0: + run_cmd += f' --save_last_n_steps_state="{int(save_last_n_steps_state)}"' + + if "save_model_as" in kwargs: + save_model_as = kwargs.get("save_model_as") + if save_model_as != "same as source model": + run_cmd += f" --save_model_as={save_model_as}" + + if "save_precision" in kwargs: + run_cmd += rf' --save_precision="{kwargs.get("save_precision")}"' + + if "save_state" in kwargs: + if kwargs.get("save_state"): + run_cmd += " --save_state" + + if "save_state_on_train_end" in kwargs: + if 
kwargs.get("save_state_on_train_end"): + run_cmd += " --save_state_on_train_end" + + if "scale_v_pred_loss_like_noise_pred" in kwargs: + if kwargs.get("scale_v_pred_loss_like_noise_pred"): + run_cmd += " --scale_v_pred_loss_like_noise_pred" + + if "scale_weight_norms" in kwargs: + scale_weight_norms = kwargs.get("scale_weight_norms") + if scale_weight_norms > 0.0: + run_cmd += f' --scale_weight_norms="{scale_weight_norms}"' + + if "seed" in kwargs: + seed = kwargs.get("seed") + if seed != "": + run_cmd += f' --seed="{seed}"' + + if "shuffle_caption" in kwargs: + if kwargs.get("shuffle_caption"): + run_cmd += " --shuffle_caption" + + if "stop_text_encoder_training" in kwargs: + stop_text_encoder_training = kwargs.get("stop_text_encoder_training") + if stop_text_encoder_training > 0: + run_cmd += f' --stop_text_encoder_training="{stop_text_encoder_training}"' + + if "text_encoder_lr" in kwargs: + text_encoder_lr = kwargs.get("text_encoder_lr") + if float(text_encoder_lr) > 0: + run_cmd += f" --text_encoder_lr={text_encoder_lr}" + + if "train_batch_size" in kwargs: + run_cmd += rf' --train_batch_size="{kwargs.get("train_batch_size")}"' training_comment = kwargs.get("training_comment") if training_comment and len(training_comment): @@ -1616,6 +1591,16 @@ def SaveConfigFile( if name not in exclusion } + # Check if the folder path for the file_path is valid + # Extrach folder path + folder_path = os.path.dirname(file_path) + + # Check if the folder exists + if not os.path.exists(folder_path): + # If not, create the folder + os.makedirs(os.path.dirname(folder_path)) + log.info(f"Creating folder {folder_path} for the configuration file...") + # Save the data to the specified JSON file with open(file_path, "w") as file: json.dump(variables, file, indent=2) diff --git a/kohya_gui/convert_lcm_gui.py b/kohya_gui/convert_lcm_gui.py index 81b0f5d35..dbc928e45 100644 --- a/kohya_gui/convert_lcm_gui.py +++ b/kohya_gui/convert_lcm_gui.py @@ -22,17 +22,12 @@ PYTHON = sys.executable -def convert_lcm( - name, - model_path, - lora_scale, - model_type -): - run_cmd = fr'"{PYTHON}" "{scriptdir}/tools/lcm_convert.py"' +def convert_lcm(name, model_path, lora_scale, model_type): + run_cmd = rf'"{PYTHON}" "{scriptdir}/tools/lcm_convert.py"' # Check if source model exist if not os.path.isfile(model_path): - log.error('The provided DyLoRA model is not a file') + log.error("The provided DyLoRA model is not a file") return if os.path.dirname(name) == "": @@ -46,12 +41,11 @@ def convert_lcm( path, ext = os.path.splitext(save_to) save_to = f"{path}_lcm{ext}" - # Construct the command to run the script run_cmd += f" --lora-scale {lora_scale}" run_cmd += f' --model "{model_path}"' run_cmd += f' --name "{name}"' - + if model_type == "SDXL": run_cmd += f" --sdxl" if model_type == "SSD-1B": @@ -60,7 +54,9 @@ def convert_lcm( log.info(run_cmd) env = os.environ.copy() - env['PYTHONPATH'] = fr"{scriptdir}{os.pathsep}{scriptdir}/sd-scripts{os.pathsep}{env.get('PYTHONPATH', '')}" + env["PYTHONPATH"] = ( + rf"{scriptdir}{os.pathsep}{scriptdir}/sd-scripts{os.pathsep}{env.get('PYTHONPATH', '')}" + ) # Run the command subprocess.run(run_cmd, shell=True, env=env) @@ -98,11 +94,16 @@ def list_save_to(path): value="", allow_custom_value=True, ) - create_refresh_button(model_path, lambda: None, lambda: {"choices": list_models(current_model_dir)}, "open_folder_small") + create_refresh_button( + model_path, + lambda: None, + lambda: {"choices": list_models(current_model_dir)}, + "open_folder_small", + ) button_model_path_file = gr.Button( 
folder_symbol,
                     elem_id="open_folder_small",
-                    elem_classes=['tool'],
+                    elem_classes=["tool"],
                     visible=(not headless),
                 )
                 button_model_path_file.click(
@@ -119,11 +120,16 @@ def list_save_to(path):
                     value="",
                     allow_custom_value=True,
                 )
-                create_refresh_button(name, lambda: None, lambda: {"choices": list_save_to(current_save_dir)}, "open_folder_small")
+                create_refresh_button(
+                    name,
+                    lambda: None,
+                    lambda: {"choices": list_save_to(current_save_dir)},
+                    "open_folder_small",
+                )
                 button_name = gr.Button(
                     folder_symbol,
                     elem_id="open_folder_small",
-                    elem_classes=['tool'],
+                    elem_classes=["tool"],
                     visible=(not headless),
                 )
                 button_name.click(
@@ -154,7 +160,7 @@ def list_save_to(path):
                     value=1.0,
                     interactive=True,
                 )
-            # with gr.Row():
+            # with gr.Row():
             #     no_half = gr.Checkbox(label="Convert the new LCM model to FP32", value=False)
             model_type = gr.Radio(
                 label="Model type", choices=["SD15", "SDXL", "SSD-1B"], value="SD15"
             )
@@ -164,11 +170,6 @@ def list_save_to(path):
 
         extract_button.click(
             convert_lcm,
-            inputs=[
-                name,
-                model_path,
-                lora_scale,
-                model_type
-            ],
+            inputs=[name, model_path, lora_scale, model_type],
             show_progress=False,
         )
diff --git a/kohya_gui/convert_model_gui.py b/kohya_gui/convert_model_gui.py
index 500e51066..f8fec7473 100644
--- a/kohya_gui/convert_model_gui.py
+++ b/kohya_gui/convert_model_gui.py
@@ -2,7 +2,6 @@
 from easygui import msgbox
 import subprocess
 import os
-import shutil
 import sys
 from .common_gui import get_folder_path, get_file_path, scriptdir, list_files, list_dirs
 
@@ -11,10 +10,10 @@
 # Set up logging
 log = setup_logging()
 
-folder_symbol = '\U0001f4c2'  # 📂
-refresh_symbol = '\U0001f504'  # 🔄
-save_style_symbol = '\U0001f4be'  # 💾
-document_symbol = '\U0001F4C4'  # 📄
+folder_symbol = "\U0001f4c2"  # 📂
+refresh_symbol = "\U0001f504"  # 🔄
+save_style_symbol = "\U0001f4be"  # 💾
+document_symbol = "\U0001F4C4"  # 📄
 
 PYTHON = sys.executable
 
@@ -29,52 +28,51 @@ def convert_model(
     unet_use_linear_projection,
 ):
     # Check for caption_text_input
-    if source_model_type == '':
-        msgbox('Invalid source model type')
+    if source_model_type == "":
+        msgbox("Invalid source model type")
         return
 
     # Check if source model exist
     if os.path.isfile(source_model_input):
-        log.info('The provided source model is a file')
+        log.info("The provided source model is a file")
    elif os.path.isdir(source_model_input):
-        log.info('The provided model is a folder')
+        log.info("The provided model is a folder")
     else:
-        msgbox('The provided source model is neither a file nor a folder')
+        msgbox("The provided source model is neither a file nor a folder")
         return
 
     # Check if source model exist
     if os.path.isdir(target_model_folder_input):
-        log.info('The provided model folder exist')
+        log.info("The provided model folder exists")
     else:
-        msgbox('The provided target folder does not exist')
+        msgbox("The provided target folder does not exist")
         return
 
-    run_cmd = fr'"{PYTHON}" "{scriptdir}/sd-scripts/tools/convert_diffusers20_original_sd.py"'
+    run_cmd = (
+        rf'"{PYTHON}" "{scriptdir}/sd-scripts/tools/convert_diffusers20_original_sd.py"'
+    )
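For orientation, this is roughly the shape of the command convert_model assembles once the flags below have been appended — a hypothetical rendering assuming a v1 checkpoint and an fp16 target; the paths are illustrative, not taken from the patch:

# Hypothetical assembled command (illustrative paths):
#   "python" "<scriptdir>/sd-scripts/tools/convert_diffusers20_original_sd.py" --v1 --fp16 \
#       "/models/v1-5-pruned.ckpt" "/models/converted.safetensors"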
 
     v1_models = [
-        'runwayml/stable-diffusion-v1-5',
-        'CompVis/stable-diffusion-v1-4',
+        "runwayml/stable-diffusion-v1-5",
+        "CompVis/stable-diffusion-v1-4",
     ]
 
     # check if v1 models
     if str(source_model_type) in v1_models:
-        log.info('SD v1 model specified. Setting --v1 parameter')
-        run_cmd += ' --v1'
+        log.info("SD v1 model specified. Setting --v1 parameter")
+        run_cmd += " --v1"
     else:
-        log.info('SD v2 model specified. Setting --v2 parameter')
-        run_cmd += ' --v2'
+        log.info("SD v2 model specified. Setting --v2 parameter")
+        run_cmd += " --v2"
 
-    if not target_save_precision_type == 'unspecified':
-        run_cmd += f' --{target_save_precision_type}'
+    if not target_save_precision_type == "unspecified":
+        run_cmd += f" --{target_save_precision_type}"
 
-    if (
-        target_model_type == 'diffuser'
-        or target_model_type == 'diffuser_safetensors'
-    ):
+    if target_model_type == "diffuser" or target_model_type == "diffuser_safetensors":
         run_cmd += f' --reference_model="{source_model_type}"'
 
-    if target_model_type == 'diffuser_safetensors':
-        run_cmd += ' --use_safetensors'
+    if target_model_type == "diffuser_safetensors":
+        run_cmd += " --use_safetensors"
 
     # Fix for stabilityAI diffusers format. When saving v2 models in Diffusers format in training scripts and conversion scripts,
     # it was found that the U-Net configuration is different from those of Hugging Face's stabilityai models (this repository is
     # `"use_linear_projection": false`, stabilityai is `true`). Please note that the weight shapes are different, so be careful
     # when using the weight files directly.
 
     if unet_use_linear_projection:
-        run_cmd += ' --unet_use_linear_projection'
+        run_cmd += " --unet_use_linear_projection"
 
     run_cmd += f' "{source_model_input}"'
 
-    if (
-        target_model_type == 'diffuser'
-        or target_model_type == 'diffuser_safetensors'
-    ):
+    if target_model_type == "diffuser" or target_model_type == "diffuser_safetensors":
         target_model_path = os.path.join(
             target_model_folder_input, target_model_name_input
         )
 
     else:
         target_model_path = os.path.join(
             target_model_folder_input,
-            f'{target_model_name_input}.{target_model_type}',
+            f"{target_model_name_input}.{target_model_type}",
         )
 
     run_cmd += f' "{target_model_path}"'
 
     log.info(run_cmd)
 
     env = os.environ.copy()
-    env['PYTHONPATH'] = fr"{scriptdir}{os.pathsep}{scriptdir}/sd-scripts{os.pathsep}{env.get('PYTHONPATH', '')}"
+    env["PYTHONPATH"] = (
+        rf"{scriptdir}{os.pathsep}{scriptdir}/sd-scripts{os.pathsep}{env.get('PYTHONPATH', '')}"
+    )
 
     # Run the command
     subprocess.run(run_cmd, shell=True, env=env)
 
-    if (
-        not target_model_type == 'diffuser'
-        or target_model_type == 'diffuser_safetensors'
-    ):
-
-        v2_models = [
-            'stabilityai/stable-diffusion-2-1-base',
-            'stabilityai/stable-diffusion-2-base',
-        ]
-        v_parameterization = [
-            'stabilityai/stable-diffusion-2-1',
-            'stabilityai/stable-diffusion-2',
-        ]
-
-        if str(source_model_type) in v2_models:
-            inference_file = os.path.join(
-                target_model_folder_input, f'{target_model_name_input}.yaml'
-            )
-            log.info(f'Saving v2-inference.yaml as {inference_file}')
-            shutil.copy(
-                fr'{scriptdir}/v2_inference/v2-inference.yaml',
-                f'{inference_file}',
-            )
-
-        if str(source_model_type) in v_parameterization:
-            inference_file = os.path.join(
-                target_model_folder_input, f'{target_model_name_input}.yaml'
-            )
-            log.info(f'Saving v2-inference-v.yaml as {inference_file}')
-            shutil.copy(
-                fr'{scriptdir}/v2_inference/v2-inference-v.yaml',
-                f'{inference_file}',
-            )
-
-
-# parser = argparse.ArgumentParser()
-# parser.add_argument("--v1", action='store_true',
-#                     help='load v1.x model (v1 or v2 is required to load checkpoint) / 1.xのモデルを読み込む')
-# parser.add_argument("--v2", action='store_true',
-#                     help='load v2.0 model (v1 or v2 is required to load checkpoint) / 2.0のモデルを読み込む')
-# parser.add_argument("--fp16", action='store_true',
-#                     help='load as fp16 (Diffusers only) and save as fp16 (checkpoint only) / fp16形式で読み込み(Diffusers形式のみ対応)、保存する(checkpointのみ対応)')
-# parser.add_argument("--bf16", action='store_true', help='save as
bf16 (checkpoint only) / bf16形式で保存する(checkpointのみ対応)') -# parser.add_argument("--float", action='store_true', -# help='save as float (checkpoint only) / float(float32)形式で保存する(checkpointのみ対応)') -# parser.add_argument("--epoch", type=int, default=0, help='epoch to write to checkpoint / checkpointに記録するepoch数の値') -# parser.add_argument("--global_step", type=int, default=0, -# help='global_step to write to checkpoint / checkpointに記録するglobal_stepの値') -# parser.add_argument("--reference_model", type=str, default=None, -# help="reference model for schduler/tokenizer, required in saving Diffusers, copy schduler/tokenizer from this / scheduler/tokenizerのコピー元のDiffusersモデル、Diffusers形式で保存するときに必要") - -# parser.add_argument("model_to_load", type=str, default=None, -# help="model to load: checkpoint file or Diffusers model's directory / 読み込むモデル、checkpointかDiffusers形式モデルのディレクトリ") -# parser.add_argument("model_to_save", type=str, default=None, -# help="model to save: checkpoint (with extension) or Diffusers model's directory (without extension) / 変換後のモデル、拡張子がある場合はcheckpoint、ない場合はDiffusesモデルとして保存") - ### # Gradio UI @@ -189,124 +130,136 @@ def list_target_folder(path): current_target_folder = path return list(list_dirs(path)) - with gr.Tab('Convert model'): + with gr.Tab("Convert model"): gr.Markdown( - 'This utility can be used to convert from one stable diffusion model format to another.' + "This utility can be used to convert from one stable diffusion model format to another." ) - model_ext = gr.Textbox(value='*.safetensors *.ckpt', visible=False) - model_ext_name = gr.Textbox(value='Model types', visible=False) + model_ext = gr.Textbox(value="*.safetensors *.ckpt", visible=False) + model_ext_name = gr.Textbox(value="Model types", visible=False) with gr.Group(), gr.Row(): - with gr.Column(), gr.Row(): - source_model_input = gr.Dropdown( - label='Source model (path to source model folder of file to convert...)', - interactive=True, - choices=[""] + list_source_model(default_source_model), - value="", - allow_custom_value=True, - ) - create_refresh_button(source_model_input, lambda: None, lambda: {"choices": list_source_model(current_source_model)}, "open_folder_small") - button_source_model_dir = gr.Button( - folder_symbol, - elem_id='open_folder_small', - elem_classes=['tool'], - visible=(not headless), - ) - button_source_model_dir.click( - get_folder_path, - outputs=source_model_input, - show_progress=False, - ) - - button_source_model_file = gr.Button( - document_symbol, - elem_id='open_folder_small', - elem_classes=['tool'], - visible=(not headless), - ) - button_source_model_file.click( - get_file_path, - inputs=[source_model_input, model_ext, model_ext_name], - outputs=source_model_input, - show_progress=False, - ) - - source_model_input.change( - fn=lambda path: gr.Dropdown(choices=[""] + list_source_model(path)), - inputs=source_model_input, - outputs=source_model_input, - show_progress=False, - ) - with gr.Column(), gr.Row(): - source_model_type = gr.Dropdown( - label='Source model type', - choices=[ - 'stabilityai/stable-diffusion-2-1-base', - 'stabilityai/stable-diffusion-2-base', - 'stabilityai/stable-diffusion-2-1', - 'stabilityai/stable-diffusion-2', - 'runwayml/stable-diffusion-v1-5', - 'CompVis/stable-diffusion-v1-4', - ], - ) + with gr.Column(), gr.Row(): + source_model_input = gr.Dropdown( + label="Source model (path to source model folder of file to convert...)", + interactive=True, + choices=[""] + list_source_model(default_source_model), + value="", + allow_custom_value=True, + ) + 
create_refresh_button( + source_model_input, + lambda: None, + lambda: {"choices": list_source_model(current_source_model)}, + "open_folder_small", + ) + button_source_model_dir = gr.Button( + folder_symbol, + elem_id="open_folder_small", + elem_classes=["tool"], + visible=(not headless), + ) + button_source_model_dir.click( + get_folder_path, + outputs=source_model_input, + show_progress=False, + ) + + button_source_model_file = gr.Button( + document_symbol, + elem_id="open_folder_small", + elem_classes=["tool"], + visible=(not headless), + ) + button_source_model_file.click( + get_file_path, + inputs=[source_model_input, model_ext, model_ext_name], + outputs=source_model_input, + show_progress=False, + ) + + source_model_input.change( + fn=lambda path: gr.Dropdown(choices=[""] + list_source_model(path)), + inputs=source_model_input, + outputs=source_model_input, + show_progress=False, + ) + with gr.Column(), gr.Row(): + source_model_type = gr.Dropdown( + label="Source model type", + choices=[ + "stabilityai/stable-diffusion-2-1-base", + "stabilityai/stable-diffusion-2-base", + "stabilityai/stable-diffusion-2-1", + "stabilityai/stable-diffusion-2", + "runwayml/stable-diffusion-v1-5", + "CompVis/stable-diffusion-v1-4", + ], + ) with gr.Group(), gr.Row(): - with gr.Column(), gr.Row(): - target_model_folder_input = gr.Dropdown( - label='Target model folder (path to target model folder of file name to create...)', - interactive=True, - choices=[""] + list_target_folder(default_target_folder), - value="", - allow_custom_value=True, - ) - create_refresh_button(target_model_folder_input, lambda: None, lambda: {"choices": list_target_folder(current_target_folder)},"open_folder_small") - button_target_model_folder = gr.Button( - folder_symbol, - elem_id='open_folder_small', - elem_classes=['tool'], - visible=(not headless), - ) - button_target_model_folder.click( - get_folder_path, - outputs=target_model_folder_input, - show_progress=False, - ) - - target_model_folder_input.change( - fn=lambda path: gr.Dropdown(choices=[""] + list_target_folder(path)), - inputs=target_model_folder_input, - outputs=target_model_folder_input, - show_progress=False, - ) - - with gr.Column(), gr.Row(): - target_model_name_input = gr.Textbox( - label='Target model name', - placeholder='target model name...', - interactive=True, - ) + with gr.Column(), gr.Row(): + target_model_folder_input = gr.Dropdown( + label="Target model folder (path to target model folder of file name to create...)", + interactive=True, + choices=[""] + list_target_folder(default_target_folder), + value="", + allow_custom_value=True, + ) + create_refresh_button( + target_model_folder_input, + lambda: None, + lambda: {"choices": list_target_folder(current_target_folder)}, + "open_folder_small", + ) + button_target_model_folder = gr.Button( + folder_symbol, + elem_id="open_folder_small", + elem_classes=["tool"], + visible=(not headless), + ) + button_target_model_folder.click( + get_folder_path, + outputs=target_model_folder_input, + show_progress=False, + ) + + target_model_folder_input.change( + fn=lambda path: gr.Dropdown( + choices=[""] + list_target_folder(path) + ), + inputs=target_model_folder_input, + outputs=target_model_folder_input, + show_progress=False, + ) + + with gr.Column(), gr.Row(): + target_model_name_input = gr.Textbox( + label="Target model name", + placeholder="target model name...", + interactive=True, + ) with gr.Row(): target_model_type = gr.Dropdown( - label='Target model type', + label="Target model type", choices=[ - 
'diffuser', - 'diffuser_safetensors', - 'ckpt', - 'safetensors', + "diffuser", + "diffuser_safetensors", + "ckpt", + "safetensors", ], ) target_save_precision_type = gr.Dropdown( - label='Target model precision', - choices=['unspecified', 'fp16', 'bf16', 'float'], - value='unspecified', + label="Target model precision", + choices=["unspecified", "fp16", "bf16", "float"], + value="unspecified", ) unet_use_linear_projection = gr.Checkbox( - label='UNet linear projection', + label="UNet linear projection", value=False, info="Enable for Hugging Face's stabilityai models", ) - convert_button = gr.Button('Convert model') + convert_button = gr.Button("Convert model") convert_button.click( convert_model, diff --git a/kohya_gui/custom_logging.py b/kohya_gui/custom_logging.py index ee7e5e208..2dd9ecbf1 100644 --- a/kohya_gui/custom_logging.py +++ b/kohya_gui/custom_logging.py @@ -11,34 +11,71 @@ log = None + def setup_logging(clean=False, debug=False): global log - + if log is not None: return log - + try: - if clean and os.path.isfile('setup.log'): - os.remove('setup.log') - time.sleep(0.1) # prevent race condition + if clean and os.path.isfile("setup.log"): + os.remove("setup.log") + time.sleep(0.1) # prevent race condition except: pass - + if sys.version_info >= (3, 9): - logging.basicConfig(level=logging.DEBUG, format='%(asctime)s | %(levelname)s | %(pathname)s | %(message)s', filename='setup.log', filemode='a', encoding='utf-8', force=True) + logging.basicConfig( + level=logging.DEBUG, + format="%(asctime)s | %(levelname)s | %(pathname)s | %(message)s", + filename="setup.log", + filemode="a", + encoding="utf-8", + force=True, + ) else: - logging.basicConfig(level=logging.DEBUG, format='%(asctime)s | %(levelname)s | %(pathname)s | %(message)s', filename='setup.log', filemode='a', force=True) - - console = Console(log_time=True, log_time_format='%H:%M:%S-%f', theme=Theme({ - "traceback.border": "black", - "traceback.border.syntax_error": "black", - "inspect.value.border": "black", - })) + logging.basicConfig( + level=logging.DEBUG, + format="%(asctime)s | %(levelname)s | %(pathname)s | %(message)s", + filename="setup.log", + filemode="a", + force=True, + ) + + console = Console( + log_time=True, + log_time_format="%H:%M:%S-%f", + theme=Theme( + { + "traceback.border": "black", + "traceback.border.syntax_error": "black", + "inspect.value.border": "black", + } + ), + ) pretty_install(console=console) - traceback_install(console=console, extra_lines=1, width=console.width, word_wrap=False, indent_guides=False, suppress=[]) - rh = RichHandler(show_time=True, omit_repeated_times=False, show_level=True, show_path=False, markup=False, rich_tracebacks=True, log_time_format='%H:%M:%S-%f', level=logging.DEBUG if debug else logging.INFO, console=console) + traceback_install( + console=console, + extra_lines=1, + width=console.width, + word_wrap=False, + indent_guides=False, + suppress=[], + ) + rh = RichHandler( + show_time=True, + omit_repeated_times=False, + show_level=True, + show_path=False, + markup=False, + rich_tracebacks=True, + log_time_format="%H:%M:%S-%f", + level=logging.DEBUG if debug else logging.INFO, + console=console, + ) rh.set_name(logging.DEBUG if debug else logging.INFO) log = logging.getLogger("sd") log.addHandler(rh) - + return log diff --git a/kohya_gui/dataset_balancing_gui.py b/kohya_gui/dataset_balancing_gui.py index a56594b63..8d644d1c1 100644 --- a/kohya_gui/dataset_balancing_gui.py +++ b/kohya_gui/dataset_balancing_gui.py @@ -9,29 +9,22 @@ # Set up logging log = setup_logging() 
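As a reference point for the diff that follows: dataset_balancing_gui.py manipulates kohya_ss's `<repeats>_<concept>` folder-naming convention, renaming each concept folder so that every concept contributes a comparable number of steps per epoch. A minimal, runnable sketch of that arithmetic (`balanced_name` is a hypothetical helper written for illustration, not part of this patch):

import re

def balanced_name(subdir: str, total_steps: int, image_count: int) -> str:
    # choose repeats so each concept contributes roughly total_steps per epoch
    repeats = max(1, round(total_steps / image_count)) if image_count else 0
    stripped = re.sub(r"^\d+_", "", subdir)  # drop any existing "<repeats>_" prefix
    return f"{repeats}_{stripped}"

print(balanced_name("my_concept", 1000, 25))  # -> 40_my_concept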
-# def select_folder(): -# # Open a file dialog to select a directory -# folder = filedialog.askdirectory() - -# # Update the GUI to display the selected folder -# selected_folder_label.config(text=folder) - def dataset_balancing(concept_repeats, folder, insecure): if not concept_repeats > 0: # Display an error message if the total number of repeats is not a valid integer - msgbox('Please enter a valid integer for the total number of repeats.') + msgbox("Please enter a valid integer for the total number of repeats.") return concept_repeats = int(concept_repeats) # Check if folder exist - if folder == '' or not os.path.isdir(folder): - msgbox('Please enter a valid folder for balancing.') + if folder == "" or not os.path.isdir(folder): + msgbox("Please enter a valid folder for balancing.") return - pattern = re.compile(r'^\d+_.+$') + pattern = re.compile(r"^\d+_.+$") # Iterate over the subdirectories in the selected folder for subdir in os.listdir(folder): @@ -44,7 +37,7 @@ def dataset_balancing(concept_repeats, folder, insecure): image_files = [ f for f in files - if f.endswith(('.jpg', '.jpeg', '.png', '.gif', '.webp')) + if f.endswith((".jpg", ".jpeg", ".png", ".gif", ".webp")) ] # Count the number of image files @@ -52,20 +45,18 @@ def dataset_balancing(concept_repeats, folder, insecure): if images == 0: log.info( - f'No images of type .jpg, .jpeg, .png, .gif, .webp were found in {os.listdir(os.path.join(folder, subdir))}' + f"No images of type .jpg, .jpeg, .png, .gif, .webp were found in {os.listdir(os.path.join(folder, subdir))}" ) # Check if the subdirectory name starts with a number inside braces, # indicating that the repeats value should be multiplied - match = re.match(r'^\{(\d+\.?\d*)\}', subdir) + match = re.match(r"^\{(\d+\.?\d*)\}", subdir) if match: # Multiply the repeats value by the number inside the braces if not images == 0: repeats = max( 1, - round( - concept_repeats / images * float(match.group(1)) - ), + round(concept_repeats / images * float(match.group(1))), ) else: repeats = 0 @@ -77,32 +68,30 @@ def dataset_balancing(concept_repeats, folder, insecure): repeats = 0 # Check if the subdirectory name already has a number at the beginning - match = re.match(r'^\d+_', subdir) + match = re.match(r"^\d+_", subdir) if match: # Replace the existing number with the new number old_name = os.path.join(folder, subdir) - new_name = os.path.join( - folder, f'{repeats}_{subdir[match.end():]}' - ) + new_name = os.path.join(folder, f"{repeats}_{subdir[match.end():]}") else: # Add the new number at the beginning of the name old_name = os.path.join(folder, subdir) - new_name = os.path.join(folder, f'{repeats}_{subdir}') + new_name = os.path.join(folder, f"{repeats}_{subdir}") os.rename(old_name, new_name) else: log.info( - f'Skipping folder {subdir} because it does not match kohya_ss expected syntax...' + f"Skipping folder {subdir} because it does not match kohya_ss expected syntax..." ) - msgbox('Dataset balancing completed...') + msgbox("Dataset balancing completed...") def warning(insecure): if insecure: if boolbox( - f'WARNING!!! You have asked to rename non kohya_ss _ folders...\n\nAre you sure you want to do that?', - choices=('Yes, I like danger', 'No, get me out of here'), + f"WARNING!!! 
You have asked to rename non kohya_ss _ folders...\n\nAre you sure you want to do that?", + choices=("Yes, I like danger", "No, get me out of here"), ): return True else: @@ -113,12 +102,12 @@ def gradio_dataset_balancing_tab(headless=False): current_dataset_dir = os.path.join(scriptdir, "data") - with gr.Tab('Dreambooth/LoRA Dataset balancing'): + with gr.Tab("Dreambooth/LoRA Dataset balancing"): gr.Markdown( - 'This utility will ensure that each concept folder in the dataset folder is used equally during the training process of the dreambooth machine learning model, regardless of the number of images in each folder. It will do this by renaming the concept folders to indicate the number of times they should be repeated during training.' + "This utility will ensure that each concept folder in the dataset folder is used equally during the training process of the dreambooth machine learning model, regardless of the number of images in each folder. It will do this by renaming the concept folders to indicate the number of times they should be repeated during training." ) gr.Markdown( - 'WARNING! The use of this utility on the wrong folder can lead to unexpected folder renaming!!!' + "WARNING! The use of this utility on the wrong folder can lead to unexpected folder renaming!!!" ) with gr.Group(), gr.Row(): @@ -128,15 +117,23 @@ def list_dataset_dirs(path): return list(list_dirs(path)) select_dataset_folder_input = gr.Dropdown( - label='Dataset folder (folder containing the concepts folders to balance...)', + label="Dataset folder (folder containing the concepts folders to balance...)", interactive=True, choices=[""] + list_dataset_dirs(current_dataset_dir), value="", allow_custom_value=True, ) - create_refresh_button(select_dataset_folder_input, lambda: None, lambda: {"choices": list_dataset_dirs(current_dataset_dir)}, "open_folder_small") + create_refresh_button( + select_dataset_folder_input, + lambda: None, + lambda: {"choices": list_dataset_dirs(current_dataset_dir)}, + "open_folder_small", + ) select_dataset_folder_button = gr.Button( - '📂', elem_id='open_folder_small', elem_classes=['tool'], visible=(not headless) + "📂", + elem_id="open_folder_small", + elem_classes=["tool"], + visible=(not headless), ) select_dataset_folder_button.click( get_folder_path, @@ -147,7 +144,7 @@ def list_dataset_dirs(path): total_repeats_number = gr.Number( value=1000, interactive=True, - label='Training steps per concept per epoch', + label="Training steps per concept per epoch", ) select_dataset_folder_input.change( fn=lambda path: gr.Dropdown(choices=[""] + list_dataset_dirs(path)), @@ -156,13 +153,13 @@ def list_dataset_dirs(path): show_progress=False, ) - with gr.Accordion('Advanced options', open=False): + with gr.Accordion("Advanced options", open=False): insecure = gr.Checkbox( value=False, - label='DANGER!!! -- Insecure folder renaming -- DANGER!!!', + label="DANGER!!! 
-- Insecure folder renaming -- DANGER!!!", ) insecure.change(warning, inputs=insecure, outputs=insecure) - balance_button = gr.Button('Balance dataset') + balance_button = gr.Button("Balance dataset") balance_button.click( dataset_balancing, inputs=[ diff --git a/kohya_gui/dreambooth_folder_creation_gui.py b/kohya_gui/dreambooth_folder_creation_gui.py index 24af3c139..1a0fd3a98 100644 --- a/kohya_gui/dreambooth_folder_creation_gui.py +++ b/kohya_gui/dreambooth_folder_creation_gui.py @@ -1,8 +1,9 @@ import gradio as gr -from easygui import diropenbox, msgbox +from easygui import msgbox from .common_gui import get_folder_path, scriptdir, list_dirs, create_refresh_button import shutil import os +from .class_gui_config import KohyaSSGUIConfig from .custom_logging import setup_logging @@ -11,13 +12,13 @@ def copy_info_to_Folders_tab(training_folder): - img_folder = os.path.join(training_folder, 'img') - if os.path.exists(os.path.join(training_folder, 'reg')): - reg_folder = os.path.join(training_folder, 'reg') + img_folder = os.path.join(training_folder, "img") + if os.path.exists(os.path.join(training_folder, "reg")): + reg_folder = os.path.join(training_folder, "reg") else: - reg_folder = '' - model_folder = os.path.join(training_folder, 'model') - log_folder = os.path.join(training_folder, 'log') + reg_folder = "" + model_folder = os.path.join(training_folder, "model") + log_folder = os.path.join(training_folder, "log") return img_folder, reg_folder, model_folder, log_folder @@ -43,17 +44,17 @@ def dreambooth_folder_preparation( os.makedirs(util_training_dir_output, exist_ok=True) # Check for instance prompt - if util_instance_prompt_input == '': - msgbox('Instance prompt missing...') + if util_instance_prompt_input == "": + msgbox("Instance prompt missing...") return # Check for class prompt - if util_class_prompt_input == '': - msgbox('Class prompt missing...') + if util_class_prompt_input == "": + msgbox("Class prompt missing...") return # Create the training_dir path - if util_training_images_dir_input == '': + if util_training_images_dir_input == "": log.info( "Training images directory is missing... can't perform the required task..." ) @@ -61,64 +62,59 @@ def dreambooth_folder_preparation( else: training_dir = os.path.join( util_training_dir_output, - f'img/{int(util_training_images_repeat_input)}_{util_instance_prompt_input} {util_class_prompt_input}', + f"img/{int(util_training_images_repeat_input)}_{util_instance_prompt_input} {util_class_prompt_input}", ) # Remove folders if they exist if os.path.exists(training_dir): - log.info(f'Removing existing directory {training_dir}...') + log.info(f"Removing existing directory {training_dir}...") shutil.rmtree(training_dir) # Copy the training images to their respective directories - log.info(f'Copy {util_training_images_dir_input} to {training_dir}...') + log.info(f"Copy {util_training_images_dir_input} to {training_dir}...") shutil.copytree(util_training_images_dir_input, training_dir) - if not util_regularization_images_dir_input == '': + if not util_regularization_images_dir_input == "": # Create the regularization_dir path if not util_regularization_images_repeat_input > 0: - log.info( - 'Repeats is missing... not copying regularisation images...' - ) + log.info("Repeats is missing... 
not copying regularisation images...")
         else:
             regularization_dir = os.path.join(
                 util_training_dir_output,
-                f'reg/{int(util_regularization_images_repeat_input)}_{util_class_prompt_input}',
+                f"reg/{int(util_regularization_images_repeat_input)}_{util_class_prompt_input}",
             )
 
             # Remove folders if they exist
             if os.path.exists(regularization_dir):
-                log.info(
-                    f'Removing existing directory {regularization_dir}...'
-                )
+                log.info(f"Removing existing directory {regularization_dir}...")
                 shutil.rmtree(regularization_dir)
 
             # Copy the regularisation images to their respective directories
             log.info(
-                f'Copy {util_regularization_images_dir_input} to {regularization_dir}...'
-            )
-            shutil.copytree(
-                util_regularization_images_dir_input, regularization_dir
+                f"Copy {util_regularization_images_dir_input} to {regularization_dir}..."
             )
+            shutil.copytree(util_regularization_images_dir_input, regularization_dir)
     else:
         log.info(
-            'Regularization images directory is missing... not copying regularisation images...'
+            "Regularization images directory is missing... not copying regularisation images..."
         )
 
     # create log and model folder
     # Check if the log folder exists and create it if it doesn't
-    if not os.path.exists(os.path.join(util_training_dir_output, 'log')):
-        os.makedirs(os.path.join(util_training_dir_output, 'log'))
+    if not os.path.exists(os.path.join(util_training_dir_output, "log")):
+        os.makedirs(os.path.join(util_training_dir_output, "log"))
 
     # Check if the model folder exists and create it if it doesn't
-    if not os.path.exists(os.path.join(util_training_dir_output, 'model')):
-        os.makedirs(os.path.join(util_training_dir_output, 'model'))
+    if not os.path.exists(os.path.join(util_training_dir_output, "model")):
+        os.makedirs(os.path.join(util_training_dir_output, "model"))
 
     log.info(
-        f'Done creating kohya_ss training folder structure at {util_training_dir_output}...'
+        f"Done creating kohya_ss training folder structure at {util_training_dir_output}..."
     )
 
 
 def gradio_dreambooth_folder_creation_tab(
+    config: KohyaSSGUIConfig,
     train_data_dir_input=gr.Dropdown(),
     reg_data_dir_input=gr.Dropdown(),
     output_dir_input=gr.Dropdown(),
@@ -130,20 +126,22 @@ def gradio_dreambooth_folder_creation_tab(
     current_reg_data_dir = os.path.join(scriptdir, "data")
     current_train_output_dir = os.path.join(scriptdir, "data")
 
-    with gr.Tab('Dreambooth/LoRA Folder preparation'):
+    with gr.Tab("Dreambooth/LoRA Folder preparation"):
         gr.Markdown(
-            'This utility will create the necessary folder structure for the training images and optional regularization images needed for the kohys_ss Dreambooth/LoRA method to function correctly.'
+            "This utility will create the necessary folder structure for the training images and optional regularization images needed for the kohya_ss Dreambooth/LoRA method to function correctly."
) with gr.Row(): util_instance_prompt_input = gr.Textbox( - label='Instance prompt', - placeholder='Eg: asd', + label="Instance prompt", + placeholder="Eg: asd", interactive=True, + value=config.get(key="dataset_preparation.instance_prompt", default=""), ) util_class_prompt_input = gr.Textbox( - label='Class prompt', - placeholder='Eg: person', + label="Class prompt", + placeholder="Eg: person", interactive=True, + value=config.get(key="dataset_preparation.class_prompt", default=""), ) with gr.Group(), gr.Row(): @@ -153,15 +151,26 @@ def list_train_data_dirs(path): return list(list_dirs(path)) util_training_images_dir_input = gr.Dropdown( - label='Training images (directory containing the training images)', + label="Training images (directory containing the training images)", interactive=True, - choices=[""] + list_train_data_dirs(current_train_data_dir), - value="", + choices=[ + config.get(key="dataset_preparation.images_folder", default="") + ] + + list_train_data_dirs(current_train_data_dir), + value=config.get(key="dataset_preparation.images_folder", default=""), allow_custom_value=True, ) - create_refresh_button(util_training_images_dir_input, lambda: None, lambda: {"choices": list_train_data_dirs(current_train_data_dir)}, "open_folder_small") + create_refresh_button( + util_training_images_dir_input, + lambda: None, + lambda: {"choices": list_train_data_dirs(current_train_data_dir)}, + "open_folder_small", + ) button_util_training_images_dir_input = gr.Button( - '📂', elem_id='open_folder_small', elem_classes=['tool'], visible=(not headless) + "📂", + elem_id="open_folder_small", + elem_classes=["tool"], + visible=(not headless), ) button_util_training_images_dir_input.click( get_folder_path, @@ -169,34 +178,48 @@ def list_train_data_dirs(path): show_progress=False, ) util_training_images_repeat_input = gr.Number( - label='Repeats', - value=40, + label="Repeats", + value=config.get(key="dataset_preparation.util_training_images_repeat_input", default=40), interactive=True, - elem_id='number_input', + elem_id="number_input", ) util_training_images_dir_input.change( - fn=lambda path: gr.Dropdown(choices=[""] + list_train_data_dirs(path)), + fn=lambda path: gr.Dropdown(choices=[config.get(key="dataset_preparation.images_folder", default="")] + list_train_data_dirs(path)), inputs=util_training_images_dir_input, outputs=util_training_images_dir_input, show_progress=False, ) with gr.Group(), gr.Row(): + def list_reg_data_dirs(path): nonlocal current_reg_data_dir current_reg_data_dir = path return list(list_dirs(path)) util_regularization_images_dir_input = gr.Dropdown( - label='Regularisation images (Optional. directory containing the regularisation images)', + label="Regularisation images (Optional. 
directory containing the regularisation images)", interactive=True, - choices=[""] + list_reg_data_dirs(current_reg_data_dir), - value="", + choices=[ + config.get(key="dataset_preparation.reg_images_folder", default="") + ] + + list_reg_data_dirs(current_reg_data_dir), + value=config.get( + key="dataset_preparation.reg_images_folder", default="" + ), allow_custom_value=True, ) - create_refresh_button(util_regularization_images_dir_input, lambda: None, lambda: {"choices": list_reg_data_dir(current_reg_data_dir)}, "open_folder_small") + create_refresh_button( + util_regularization_images_dir_input, + lambda: None, + lambda: {"choices": list_reg_data_dirs(current_reg_data_dir)}, + "open_folder_small", + ) button_util_regularization_images_dir_input = gr.Button( - '📂', elem_id='open_folder_small', elem_classes=['tool'], visible=(not headless) + "📂", + elem_id="open_folder_small", + elem_classes=["tool"], + visible=(not headless), ) button_util_regularization_images_dir_input.click( get_folder_path, @@ -204,10 +227,13 @@ def list_reg_data_dirs(path): show_progress=False, ) util_regularization_images_repeat_input = gr.Number( - label='Repeats', - value=1, + label="Repeats", + value=config.get( + key="dataset_preparation.util_regularization_images_repeat_input", + default=1 + ), interactive=True, - elem_id='number_input', + elem_id="number_input", ) util_regularization_images_dir_input.change( fn=lambda path: gr.Dropdown(choices=[""] + list_reg_data_dirs(path)), @@ -216,32 +242,44 @@ def list_reg_data_dirs(path): show_progress=False, ) with gr.Group(), gr.Row(): + def list_train_output_dirs(path): nonlocal current_train_output_dir current_train_output_dir = path return list(list_dirs(path)) util_training_dir_output = gr.Dropdown( - label='Destination training directory (where formatted training and regularisation folders will be placed)', + label="Destination training directory (where formatted training and regularisation folders will be placed)", interactive=True, - choices=[""] + list_train_output_dirs(current_train_output_dir), - value="", + choices=[config.get(key="train_data_dir", default="")] + + list_train_output_dirs(current_train_output_dir), + value=config.get(key="train_data_dir", default=""), allow_custom_value=True, ) - create_refresh_button(util_training_dir_output, lambda: None, lambda: {"choices": list_train_output_dirs(current_train_output_dir)}, "open_folder_small") + create_refresh_button( + util_training_dir_output, + lambda: None, + lambda: {"choices": list_train_output_dirs(current_train_output_dir)}, + "open_folder_small", + ) button_util_training_dir_output = gr.Button( - '📂', elem_id='open_folder_small', elem_classes=['tool'], visible=(not headless) + "📂", + elem_id="open_folder_small", + elem_classes=["tool"], + visible=(not headless), ) button_util_training_dir_output.click( get_folder_path, outputs=util_training_dir_output ) util_training_dir_output.change( - fn=lambda path: gr.Dropdown(choices=[""] + list_train_output_dirs(path)), + fn=lambda path: gr.Dropdown( + choices=[config.get(key="train_data_dir", default="")] + list_train_output_dirs(path) + ), inputs=util_training_dir_output, outputs=util_training_dir_output, show_progress=False, ) - button_prepare_training_data = gr.Button('Prepare training data') + button_prepare_training_data = gr.Button("Prepare training data") button_prepare_training_data.click( dreambooth_folder_preparation, inputs=[ @@ -255,15 +293,3 @@ def list_train_output_dirs(path): ], show_progress=False, ) - button_copy_info_to_Folders_tab = 
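Note: the hunks above read GUI defaults through KohyaSSGUIConfig.get() with dotted keys such as "dataset_preparation.instance_prompt". The class itself (kohya_gui/class_gui_config.py) is not part of this diff, so the following is only a minimal sketch of what such a dotted-key lookup might look like, assuming the config is a nested dict loaded from the user's config.toml:

import toml

class KohyaSSGUIConfig:
    def __init__(self, config_file_path: str = "./config.toml"):
        try:
            with open(config_file_path, "r", encoding="utf-8") as f:
                self.config = toml.load(f)
        except FileNotFoundError:
            self.config = {}  # no saved defaults; every get() falls back

    def get(self, key: str, default=None):
        # Walk one level per dotted segment, e.g. "dataset_preparation.instance_prompt".
        data = self.config
        for part in key.split("."):
            if not isinstance(data, dict) or part not in data:
                return default
            data = data[part]
        return data

Under this assumption, a missing file or missing key simply yields the widget's default value, which matches how the dropdowns above fall back to "".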
diff --git a/kohya_gui/dreambooth_gui.py b/kohya_gui/dreambooth_gui.py
index 9dae9e4dd..f870b9655 100644
--- a/kohya_gui/dreambooth_gui.py
+++ b/kohya_gui/dreambooth_gui.py
@@ -3,13 +3,11 @@
 import math
 import os
 import sys
-import pathlib
 from datetime import datetime
 from .common_gui import (
     get_file_path,
     get_saveasfile_path,
     color_aug_changed,
-    save_inference_file,
     run_cmd_advanced_training,
     update_my_data,
     check_if_model_exist,
@@ -18,13 +16,14 @@
     scriptdir,
     validate_paths,
 )
+from .class_accelerate_launch import AccelerateLaunch
 from .class_configuration_file import ConfigurationFile
+from .class_gui_config import KohyaSSGUIConfig
 from .class_source_model import SourceModel
 from .class_basic_training import BasicTraining
 from .class_advanced_training import AdvancedTraining
 from .class_folders import Folders
 from .class_command_executor import CommandExecutor
-from .class_sdxl_parameters import SDXLParameters
 from .tensorboard_gui import (
     gradio_tensorboard,
     start_tensorboard,
@@ -89,16 +88,19 @@ def save_configuration(
     save_model_as,
     shuffle_caption,
     save_state,
+    save_state_on_train_end,
     resume,
     prior_loss_weight,
     color_aug,
     flip_aug,
+    masked_loss,
     clip_skip,
     vae,
     num_processes,
     num_machines,
     multi_gpu,
     gpu_ids,
+    main_process_port,
     output_name,
     max_token_length,
     max_train_epochs,
@@ -122,9 +124,12 @@
     lr_scheduler_args,
     noise_offset_type,
     noise_offset,
+    noise_offset_random_strength,
     adaptive_noise_scale,
     multires_noise_iterations,
     multires_noise_discount,
+    ip_noise_gamma,
+    ip_noise_gamma_random_strength,
     sample_every_n_steps,
     sample_every_n_epochs,
     sample_sampler,
@@ -144,6 +149,8 @@
     scale_v_pred_loss_like_noise_pred,
     min_timestep,
     max_timestep,
+    debiased_estimation_loss,
+    extra_accelerate_launch_args,
 ):
     # Get list of function parameters and values
     parameters = list(locals().items())
@@ -221,16 +228,19 @@ def open_configuration(
     save_model_as,
     shuffle_caption,
     save_state,
+    save_state_on_train_end,
     resume,
     prior_loss_weight,
     color_aug,
     flip_aug,
+    masked_loss,
     clip_skip,
     vae,
     num_processes,
     num_machines,
     multi_gpu,
     gpu_ids,
+    main_process_port,
     output_name,
     max_token_length,
     max_train_epochs,
@@ -254,9 +264,12 @@
     lr_scheduler_args,
     noise_offset_type,
     noise_offset,
+    noise_offset_random_strength,
     adaptive_noise_scale,
     multires_noise_iterations,
     multires_noise_discount,
+    ip_noise_gamma,
+    ip_noise_gamma_random_strength,
     sample_every_n_steps,
     sample_every_n_epochs,
     sample_sampler,
@@ -276,6 +289,8 @@
     scale_v_pred_loss_like_noise_pred,
     min_timestep,
     max_timestep,
+    debiased_estimation_loss,
+    extra_accelerate_launch_args,
 ):
     # Get list of function parameters and values
     parameters = list(locals().items())
@@ -348,16 +363,19 @@ def train_model(
     save_model_as,
     shuffle_caption,
     save_state,
+    save_state_on_train_end,
     resume,
     prior_loss_weight,
     color_aug,
     flip_aug,
+    masked_loss,
     clip_skip,
     vae,
     num_processes,
     num_machines,
     multi_gpu,
     gpu_ids,
+    main_process_port,
     output_name,
     max_token_length,
     max_train_epochs,
@@ -381,9 +399,12 @@
     lr_scheduler_args,
     noise_offset_type,
     noise_offset,
+    noise_offset_random_strength,
     adaptive_noise_scale,
     multires_noise_iterations,
     multires_noise_discount,
+    ip_noise_gamma,
+    ip_noise_gamma_random_strength,
     sample_every_n_steps,
     sample_every_n_epochs,
     sample_sampler,
@@ -403,6 +424,8 @@
     scale_v_pred_loss_like_noise_pred,
     min_timestep,
     max_timestep,
+    debiased_estimation_loss,
+    extra_accelerate_launch_args,
 ):
     # Get list of function parameters and values
     parameters = list(locals().items())
@@ -434,7 +457,9 @@
         return

     if dataset_config:
-        log.info("Dataset config toml file used, skipping total_steps, train_batch_size, gradient_accumulation_steps, epoch, reg_factor, max_train_steps calculations...")
+        log.info(
+            "Dataset config toml file used, skipping total_steps, train_batch_size, gradient_accumulation_steps, epoch, reg_factor, max_train_steps calculations..."
+        )
     else:
         # Get a list of all subfolders in train_data_dir, excluding hidden folders
         subfolders = [
@@ -484,7 +509,9 @@
             log.info(f"Folder {folder} : steps {steps}")

         if total_steps == 0:
-            log.info(f"No images were found in folder {train_data_dir}... please rectify!")
+            log.info(
+                f"No images were found in folder {train_data_dir}... please rectify!"
+            )
             return

         # Print the result
@@ -516,7 +543,9 @@
     # calculate stop encoder training
     if int(stop_text_encoder_training_pct) == -1:
         stop_text_encoder_training = -1
-    elif stop_text_encoder_training_pct == None or (not max_train_steps == "" or not max_train_steps == "0"):
+    elif stop_text_encoder_training_pct == None or (
+        not max_train_steps == "" or not max_train_steps == "0"
+    ):
         stop_text_encoder_training = 0
     else:
         stop_text_encoder_training = math.ceil(
@@ -533,12 +562,15 @@
     # run_cmd = f'accelerate launch --num_cpu_threads_per_process={num_cpu_threads_per_process} "train_db.py"'
     run_cmd = "accelerate launch"

-    run_cmd += run_cmd_advanced_training(
+    run_cmd += AccelerateLaunch.run_cmd(
         num_processes=num_processes,
         num_machines=num_machines,
         multi_gpu=multi_gpu,
         gpu_ids=gpu_ids,
+        main_process_port=main_process_port,
         num_cpu_threads_per_process=num_cpu_threads_per_process,
+        mixed_precision=mixed_precision,
+        extra_accelerate_launch_args=extra_accelerate_launch_args,
     )

     if sdxl:
@@ -559,13 +591,17 @@
         "clip_skip": clip_skip,
         "color_aug": color_aug,
         "dataset_config": dataset_config,
+        "debiased_estimation_loss": debiased_estimation_loss,
         "enable_bucket": enable_bucket,
         "epoch": epoch,
         "flip_aug": flip_aug,
+        "masked_loss": masked_loss,
         "full_bf16": full_bf16,
         "full_fp16": full_fp16,
         "gradient_accumulation_steps": gradient_accumulation_steps,
         "gradient_checkpointing": gradient_checkpointing,
+        "ip_noise_gamma": ip_noise_gamma,
+        "ip_noise_gamma_random_strength": ip_noise_gamma_random_strength,
         "keep_tokens": keep_tokens,
         "learning_rate": learning_rate,
         "logging_dir": logging_dir,
@@ -592,6 +628,7 @@
         "multires_noise_iterations": multires_noise_iterations,
         "no_token_padding": no_token_padding,
         "noise_offset": noise_offset,
+        "noise_offset_random_strength": noise_offset_random_strength,
         "noise_offset_type": noise_offset_type,
         "optimizer": optimizer,
         "optimizer_args": optimizer_args,
@@ -610,6 +647,7 @@
         "save_model_as": save_model_as,
         "save_precision": save_precision,
         "save_state": save_state,
+        "save_state_on_train_end": save_state_on_train_end,
         "scale_v_pred_loss_like_noise_pred": scale_v_pred_loss_like_noise_pred,
         "seed": seed,
         "shuffle_caption": shuffle_caption,
@@ -675,18 +713,12 @@
     env["PYTHONPATH"] = (
         rf"{scriptdir}{os.pathsep}{scriptdir}/sd-scripts{os.pathsep}{env.get('PYTHONPATH', '')}"
     )
+    env["TF_ENABLE_ONEDNN_OPTS"] = "0"

     # Run the command
     executor.execute_command(run_cmd=run_cmd, env=env)

-    # # check if output_dir/last is a folder... therefore it is a diffuser model
-    # last_dir = pathlib.Path(f"{output_dir}/{output_name}")
-
-    # if not last_dir.is_dir():
-    #     # Copy inference model for v2 if required
-    #     save_inference_file(output_dir, v2, v_parameterization, output_name)
-

 def dreambooth_tab(
     # train_data_dir=gr.Textbox(),
@@ -694,7 +726,7 @@
     # output_dir=gr.Textbox(),
     # logging_dir=gr.Textbox(),
     headless=False,
-    config: dict = {},
+    config: KohyaSSGUIConfig = {},
 ):
     dummy_db_true = gr.Label(value=True, visible=False)
     dummy_db_false = gr.Label(value=False, visible=False)
@@ -703,23 +735,26 @@
     with gr.Tab("Training"), gr.Column(variant="compact"):
         gr.Markdown("Train a custom model using kohya dreambooth python code...")

+        with gr.Accordion("Accelerate launch", open=False), gr.Column():
+            accelerate_launch = AccelerateLaunch(config=config)
+
         with gr.Column():
             source_model = SourceModel(headless=headless, config=config)

         with gr.Accordion("Folders", open=False), gr.Group():
             folders = Folders(headless=headless, config=config)

-        with gr.Accordion("Parameters", open=False), gr.Column():
-            with gr.Group(elem_id="basic_tab"):
-                basic_training = BasicTraining(
-                    learning_rate_value="1e-5",
-                    lr_scheduler_value="cosine",
-                    lr_warmup_value="10",
-                    dreambooth=True,
-                    sdxl_checkbox=source_model.sdxl_checkbox,
-                )
-            # # Add SDXL Parameters
-            # sdxl_params = SDXLParameters(source_model.sdxl_checkbox, show_sdxl_cache_text_encoder_outputs=False)
+        with gr.Accordion("Parameters", open=False), gr.Column():
+            with gr.Accordion("Basic", open="True"):
+                with gr.Group(elem_id="basic_tab"):
+                    basic_training = BasicTraining(
+                        learning_rate_value="1e-5",
+                        lr_scheduler_value="cosine",
+                        lr_warmup_value="10",
+                        dreambooth=True,
+                        sdxl_checkbox=source_model.sdxl_checkbox,
+                        config=config,
+                    )

             with gr.Accordion("Advanced", open=False, elem_id="advanced_tab"):
                 advanced_training = AdvancedTraining(headless=headless, config=config)
@@ -730,7 +765,7 @@
                 )

             with gr.Accordion("Samples", open=False, elem_id="samples_tab"):
-                sample = SampleImages()
+                sample = SampleImages(config=config)

         with gr.Accordion("Dataset Preparation", open=False):
             gr.Markdown(
@@ -742,6 +777,7 @@
                 output_dir_input=folders.output_dir,
                 logging_dir_input=folders.logging_dir,
                 headless=headless,
+                config=config,
             )
             gradio_dataset_balancing_tab(headless=headless)

@@ -795,10 +831,10 @@
         basic_training.train_batch_size,
         basic_training.epoch,
         basic_training.save_every_n_epochs,
-        basic_training.mixed_precision,
+        accelerate_launch.mixed_precision,
         source_model.save_precision,
         basic_training.seed,
-        basic_training.num_cpu_threads_per_process,
+        accelerate_launch.num_cpu_threads_per_process,
         basic_training.cache_latents,
         basic_training.cache_latents_to_disk,
         basic_training.caption_extension,
@@ -814,16 +850,19 @@
         source_model.save_model_as,
         advanced_training.shuffle_caption,
         advanced_training.save_state,
+        advanced_training.save_state_on_train_end,
         advanced_training.resume,
         advanced_training.prior_loss_weight,
         advanced_training.color_aug,
         advanced_training.flip_aug,
+        advanced_training.masked_loss,
         advanced_training.clip_skip,
         advanced_training.vae,
-        advanced_training.num_processes,
-        advanced_training.num_machines,
-        advanced_training.multi_gpu,
-        advanced_training.gpu_ids,
+        accelerate_launch.num_processes,
+        accelerate_launch.num_machines,
+        accelerate_launch.multi_gpu,
+        accelerate_launch.gpu_ids,
+        accelerate_launch.main_process_port,
         source_model.output_name,
         advanced_training.max_token_length,
         basic_training.max_train_epochs,
@@ -847,9 +886,12 @@
         basic_training.lr_scheduler_args,
         advanced_training.noise_offset_type,
         advanced_training.noise_offset,
+        advanced_training.noise_offset_random_strength,
         advanced_training.adaptive_noise_scale,
         advanced_training.multires_noise_iterations,
         advanced_training.multires_noise_discount,
+        advanced_training.ip_noise_gamma,
+        advanced_training.ip_noise_gamma_random_strength,
         sample.sample_every_n_steps,
         sample.sample_every_n_epochs,
         sample.sample_sampler,
@@ -869,6 +911,8 @@
         advanced_training.scale_v_pred_loss_like_noise_pred,
         advanced_training.min_timestep,
         advanced_training.max_timestep,
+        advanced_training.debiased_estimation_loss,
+        accelerate_launch.extra_accelerate_launch_args,
     ]

     configuration.button_open_config.click(
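The dreambooth, finetune, and LoRA tabs all replace run_cmd_advanced_training with AccelerateLaunch.run_cmd when building the `accelerate launch` prefix. class_accelerate_launch.py itself is not included in this diff, so the following is only a hedged sketch of how such a static helper might map keyword arguments onto CLI flags; flag names mirror the real `accelerate launch` options, everything else is assumed:

class AccelerateLaunch:
    @staticmethod
    def run_cmd(**kwargs) -> str:
        # Returns a string fragment to append after "accelerate launch".
        cmd = ""
        if kwargs.get("multi_gpu"):
            cmd += " --multi_gpu"
        for flag in ("num_processes", "num_machines",
                     "num_cpu_threads_per_process", "main_process_port"):
            value = kwargs.get(flag)
            if value:
                cmd += f" --{flag}={int(value)}"
        if kwargs.get("gpu_ids"):
            cmd += f' --gpu_ids="{kwargs["gpu_ids"]}"'
        if kwargs.get("mixed_precision"):
            cmd += f' --mixed_precision="{kwargs["mixed_precision"]}"'
        if kwargs.get("extra_accelerate_launch_args"):
            cmd += f' {kwargs["extra_accelerate_launch_args"].strip()}'
        return cmd

Usage, as in the hunks above: run_cmd = "accelerate launch" followed by run_cmd += AccelerateLaunch.run_cmd(num_processes=..., mixed_precision=..., ...), and finally the quoted path of the training script.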
diff --git a/kohya_gui/extract_lora_from_dylora_gui.py b/kohya_gui/extract_lora_from_dylora_gui.py
index 3c01bd204..d99a15235 100644
--- a/kohya_gui/extract_lora_from_dylora_gui.py
+++ b/kohya_gui/extract_lora_from_dylora_gui.py
@@ -4,7 +4,6 @@
 import os
 import sys
 from .common_gui import (
-    get_saveasfilename_path,
     get_file_path,
     scriptdir,
     list_files,
@@ -16,10 +15,10 @@
 # Set up logging
 log = setup_logging()

-folder_symbol = '\U0001f4c2'  # 📂
-refresh_symbol = '\U0001f504'  # 🔄
-save_style_symbol = '\U0001f4be'  # 💾
-document_symbol = '\U0001F4C4'  # 📄
+folder_symbol = "\U0001f4c2"  # 📂
+refresh_symbol = "\U0001f504"  # 🔄
+save_style_symbol = "\U0001f4be"  # 💾
+document_symbol = "\U0001F4C4"  # 📄

 PYTHON = sys.executable

@@ -30,13 +29,13 @@ def extract_dylora(
     unit,
 ):
     # Check for caption_text_input
-    if model == '':
-        msgbox('Invalid DyLoRA model file')
+    if model == "":
+        msgbox("Invalid DyLoRA model file")
         return

     # Check if source model exist
     if not os.path.isfile(model):
-        msgbox('The provided DyLoRA model is not a file')
+        msgbox("The provided DyLoRA model is not a file")
         return

     if os.path.dirname(save_to) == "":
@@ -51,21 +50,23 @@
         save_to = f"{path}_tmp{ext}"

     run_cmd = (
-        fr'"{PYTHON}" "{scriptdir}/sd-scripts/networks/extract_lora_from_dylora.py"'
+        rf'"{PYTHON}" "{scriptdir}/sd-scripts/networks/extract_lora_from_dylora.py"'
     )
-    run_cmd += fr' --save_to "{save_to}"'
-    run_cmd += fr' --model "{model}"'
-    run_cmd += f' --unit {unit}'
+    run_cmd += rf' --save_to "{save_to}"'
+    run_cmd += rf' --model "{model}"'
+    run_cmd += f" --unit {unit}"

     log.info(run_cmd)

     env = os.environ.copy()
-    env['PYTHONPATH'] = fr"{scriptdir}{os.pathsep}{scriptdir}/sd-scripts{os.pathsep}{env.get('PYTHONPATH', '')}"
+    env["PYTHONPATH"] = (
+        rf"{scriptdir}{os.pathsep}{scriptdir}/sd-scripts{os.pathsep}{env.get('PYTHONPATH', '')}"
+    )

     # Run the command
     subprocess.run(run_cmd, shell=True, env=env)

-    log.info('Done extracting DyLoRA...')
+    log.info("Done extracting DyLoRA...")


 ###

@@ -77,12 +78,10 @@
     current_model_dir = os.path.join(scriptdir, "outputs")
     current_save_dir = os.path.join(scriptdir, "outputs")

-    with gr.Tab('Extract DyLoRA'):
-        gr.Markdown(
-            'This utility can extract a DyLoRA network from a finetuned model.'
-        )
-        lora_ext = gr.Textbox(value='*.safetensors *.pt', visible=False)
-        lora_ext_name = gr.Textbox(value='LoRA model types', visible=False)
+    with gr.Tab("Extract DyLoRA"):
+        gr.Markdown("This utility can extract a DyLoRA network from a finetuned model.")
+        lora_ext = gr.Textbox(value="*.safetensors *.pt", visible=False)
+        lora_ext_name = gr.Textbox(value="LoRA model types", visible=False)

         def list_models(path):
             nonlocal current_model_dir
@@ -96,17 +95,22 @@ def list_save_to(path):

         with gr.Group(), gr.Row():
             model = gr.Dropdown(
-                label='DyLoRA model (path to the DyLoRA model to extract from)',
+                label="DyLoRA model (path to the DyLoRA model to extract from)",
                 interactive=True,
                 choices=[""] + list_models(current_model_dir),
                 value="",
                 allow_custom_value=True,
             )
-            create_refresh_button(model, lambda: None, lambda: {"choices": list_models(current_model_dir)}, "open_folder_small")
+            create_refresh_button(
+                model,
+                lambda: None,
+                lambda: {"choices": list_models(current_model_dir)},
+                "open_folder_small",
+            )
             button_model_file = gr.Button(
                 folder_symbol,
-                elem_id='open_folder_small',
-                elem_classes=['tool'],
+                elem_id="open_folder_small",
+                elem_classes=["tool"],
                 visible=(not headless),
             )
             button_model_file.click(
@@ -117,17 +121,22 @@ def list_save_to(path):

             save_to = gr.Dropdown(
-                label='Save to (path where to save the extracted LoRA model...)',
+                label="Save to (path where to save the extracted LoRA model...)",
                 interactive=True,
                 choices=[""] + list_save_to(current_save_dir),
                 value="",
                 allow_custom_value=True,
             )
-            create_refresh_button(save_to, lambda: None, lambda: {"choices": list_save_to(current_save_dir)}, "open_folder_small")
+            create_refresh_button(
+                save_to,
+                lambda: None,
+                lambda: {"choices": list_save_to(current_save_dir)},
+                "open_folder_small",
+            )
             unit = gr.Slider(
                 minimum=1,
                 maximum=256,
-                label='Network Dimension (Rank)',
+                label="Network Dimension (Rank)",
                 value=1,
                 step=1,
                 interactive=True,
@@ -146,7 +155,7 @@ def list_save_to(path):
             show_progress=False,
         )

-        extract_button = gr.Button('Extract LoRA model')
+        extract_button = gr.Button("Extract LoRA model")

         extract_button.click(
             extract_dylora,
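A note on the PYTHONPATH pattern that recurs in every file of this diff: the training and tool subprocesses must resolve both the GUI package and the bundled sd-scripts checkout. The reformatted lines are behaviorally identical to the old ones; spelled out on its own (scriptdir here is a stand-in for the value imported from common_gui):

import os

scriptdir = "/app"  # assumed; the real value comes from kohya_gui.common_gui.scriptdir
env = os.environ.copy()
env["PYTHONPATH"] = (
    rf"{scriptdir}{os.pathsep}{scriptdir}/sd-scripts{os.pathsep}{env.get('PYTHONPATH', '')}"
)
# os.pathsep keeps this portable: ";" on Windows, ":" on Linux. Any PYTHONPATH
# already set by the caller is preserved at the end of the search order.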
diff --git a/kohya_gui/extract_lora_gui.py b/kohya_gui/extract_lora_gui.py
index f0d0aa4fa..66c1e6123 100644
--- a/kohya_gui/extract_lora_gui.py
+++ b/kohya_gui/extract_lora_gui.py
@@ -1,5 +1,4 @@
 import gradio as gr
-from easygui import msgbox
 import subprocess
 import os
 import sys
@@ -17,10 +16,10 @@
 # Set up logging
 log = setup_logging()

-folder_symbol = '\U0001f4c2'  # 📂
-refresh_symbol = '\U0001f504'  # 🔄
-save_style_symbol = '\U0001f4be'  # 💾
-document_symbol = '\U0001F4C4'  # 📄
+folder_symbol = "\U0001f4c2"  # 📂
+refresh_symbol = "\U0001f504"  # 🔄
+save_style_symbol = "\U0001f4be"  # 💾
+document_symbol = "\U0001F4C4"  # 📄

 PYTHON = sys.executable

@@ -42,21 +41,21 @@ def extract_lora(
     load_precision,
 ):
     # Check for caption_text_input
-    if model_tuned == '':
-        log.info('Invalid finetuned model file')
+    if model_tuned == "":
+        log.info("Invalid finetuned model file")
         return

-    if model_org == '':
-        log.info('Invalid base model file')
+    if model_org == "":
+        log.info("Invalid base model file")
         return

     # Check if source model exist
     if not os.path.isfile(model_tuned):
-        log.info('The provided finetuned model is not a file')
+        log.info("The provided finetuned model is not a file")
         return

     if not os.path.isfile(model_org):
-        log.info('The provided base model is not a file')
+        log.info("The provided base model is not a file")
         return

     if os.path.dirname(save_to) == "":
@@ -74,31 +73,33 @@ def extract_lora(
         return

     run_cmd = (
-        fr'"{PYTHON}" "{scriptdir}/sd-scripts/networks/extract_lora_from_models.py"'
+        rf'"{PYTHON}" "{scriptdir}/sd-scripts/networks/extract_lora_from_models.py"'
     )
-    run_cmd += f' --load_precision {load_precision}'
-    run_cmd += f' --save_precision {save_precision}'
-    run_cmd += fr' --save_to "{save_to}"'
-    run_cmd += fr' --model_org "{model_org}"'
-    run_cmd += fr' --model_tuned "{model_tuned}"'
-    run_cmd += f' --dim {dim}'
-    run_cmd += f' --device {device}'
+    run_cmd += f" --load_precision {load_precision}"
+    run_cmd += f" --save_precision {save_precision}"
+    run_cmd += rf' --save_to "{save_to}"'
+    run_cmd += rf' --model_org "{model_org}"'
+    run_cmd += rf' --model_tuned "{model_tuned}"'
+    run_cmd += f" --dim {dim}"
+    run_cmd += f" --device {device}"
     if conv_dim > 0:
-        run_cmd += f' --conv_dim {conv_dim}'
+        run_cmd += f" --conv_dim {conv_dim}"
     if v2:
-        run_cmd += f' --v2'
+        run_cmd += f" --v2"
     if sdxl:
-        run_cmd += f' --sdxl'
-    run_cmd += f' --clamp_quantile {clamp_quantile}'
-    run_cmd += f' --min_diff {min_diff}'
+        run_cmd += f" --sdxl"
+    run_cmd += f" --clamp_quantile {clamp_quantile}"
+    run_cmd += f" --min_diff {min_diff}"
     if sdxl:
-        run_cmd += f' --load_original_model_to {load_original_model_to}'
-        run_cmd += f' --load_tuned_model_to {load_tuned_model_to}'
+        run_cmd += f" --load_original_model_to {load_original_model_to}"
+        run_cmd += f" --load_tuned_model_to {load_tuned_model_to}"

     log.info(run_cmd)

     env = os.environ.copy()
-    env['PYTHONPATH'] = fr"{scriptdir}{os.pathsep}{scriptdir}/sd-scripts{os.pathsep}{env.get('PYTHONPATH', '')}"
+    env["PYTHONPATH"] = (
+        rf"{scriptdir}{os.pathsep}{scriptdir}/sd-scripts{os.pathsep}{env.get('PYTHONPATH', '')}"
+    )

     # Run the command
     subprocess.run(run_cmd, shell=True, env=env)

@@ -132,29 +133,31 @@ def list_save_to(path):

     def change_sdxl(sdxl):
         return gr.Dropdown(visible=sdxl), gr.Dropdown(visible=sdxl)

-    with gr.Tab('Extract LoRA'):
-        gr.Markdown(
-            'This utility can extract a LoRA network from a finetuned model.'
-        )
-        lora_ext = gr.Textbox(value='*.safetensors *.pt', visible=False)
-        lora_ext_name = gr.Textbox(value='LoRA model types', visible=False)
-        model_ext = gr.Textbox(value='*.ckpt *.safetensors', visible=False)
-        model_ext_name = gr.Textbox(value='Model types', visible=False)
+    with gr.Tab("Extract LoRA"):
+        gr.Markdown("This utility can extract a LoRA network from a finetuned model.")
+        lora_ext = gr.Textbox(value="*.safetensors *.pt", visible=False)
+        lora_ext_name = gr.Textbox(value="LoRA model types", visible=False)
+        model_ext = gr.Textbox(value="*.ckpt *.safetensors", visible=False)
+        model_ext_name = gr.Textbox(value="Model types", visible=False)

         with gr.Group(), gr.Row():
             model_tuned = gr.Dropdown(
-                label='Finetuned model (path to the finetuned model to extract)',
+                label="Finetuned model (path to the finetuned model to extract)",
                 interactive=True,
                 choices=[""] + list_models(current_model_dir),
                 value="",
                 allow_custom_value=True,
             )
-            create_refresh_button(model_tuned, lambda: None, lambda: {"choices": list_models(current_model_dir)}, "open_folder_small")
+            create_refresh_button(
+                model_tuned,
+                lambda: None,
+                lambda: {"choices": list_models(current_model_dir)},
+                "open_folder_small",
+            )
             button_model_tuned_file = gr.Button(
                 folder_symbol,
-                elem_id='open_folder_small',
-                elem_classes=['tool'],
+                elem_id="open_folder_small",
+                elem_classes=["tool"],
                 visible=(not headless),
             )
             button_model_tuned_file.click(
@@ -164,25 +167,31 @@ def change_sdxl(sdxl):
                 show_progress=False,
             )
             load_tuned_model_to = gr.Radio(
-                label='Load finetuned model to',
-                choices=['cpu', 'cuda', 'cuda:0'],
-                value='cpu',
-                interactive=True, scale=1,
+                label="Load finetuned model to",
+                choices=["cpu", "cuda", "cuda:0"],
+                value="cpu",
+                interactive=True,
+                scale=1,
                 info="only for SDXL",
                 visible=False,
             )

             model_org = gr.Dropdown(
-                label='Stable Diffusion base model (original model: ckpt or safetensors file)',
+                label="Stable Diffusion base model (original model: ckpt or safetensors file)",
                 interactive=True,
                 choices=[""] + list_org_models(current_model_org_dir),
                 value="",
                 allow_custom_value=True,
             )
-            create_refresh_button(model_org, lambda: None, lambda: {"choices": list_org_models(current_model_org_dir)}, "open_folder_small")
+            create_refresh_button(
+                model_org,
+                lambda: None,
+                lambda: {"choices": list_org_models(current_model_org_dir)},
+                "open_folder_small",
+            )
             button_model_org_file = gr.Button(
                 folder_symbol,
-                elem_id='open_folder_small',
-                elem_classes=['tool'],
+                elem_id="open_folder_small",
+                elem_classes=["tool"],
                 visible=(not headless),
             )
             button_model_org_file.click(
@@ -192,27 +201,33 @@ def change_sdxl(sdxl):
                 show_progress=False,
             )
             load_original_model_to = gr.Dropdown(
-                label='Load Stable Diffusion base model to',
-                choices=['cpu', 'cuda', 'cuda:0'],
-                value='cpu',
-                interactive=True, scale=1,
+                label="Load Stable Diffusion base model to",
+                choices=["cpu", "cuda", "cuda:0"],
+                value="cpu",
+                interactive=True,
+                scale=1,
                 info="only for SDXL",
                 visible=False,
             )

         with gr.Group(), gr.Row():
             save_to = gr.Dropdown(
-                label='Save to (path where to save the extracted LoRA model...)',
+                label="Save to (path where to save the extracted LoRA model...)",
                 interactive=True,
                 choices=[""] + list_save_to(current_save_dir),
                 value="",
                 allow_custom_value=True,
                 scale=2,
             )
-            create_refresh_button(save_to, lambda: None, lambda: {"choices": list_save_to(current_save_dir)}, "open_folder_small")
+            create_refresh_button(
+                save_to,
+                lambda: None,
+                lambda: {"choices": list_save_to(current_save_dir)},
+                "open_folder_small",
+            )
             button_save_to = gr.Button(
                 folder_symbol,
-                elem_id='open_folder_small',
-                elem_classes=['tool'],
+                elem_id="open_folder_small",
+                elem_classes=["tool"],
                 visible=(not headless),
             )
             button_save_to.click(
@@ -222,16 +237,18 @@ def change_sdxl(sdxl):
                 show_progress=False,
             )
             save_precision = gr.Radio(
-                label='Save precision',
-                choices=['fp16', 'bf16', 'float'],
-                value='fp16',
-                interactive=True, scale=1,
+                label="Save precision",
+                choices=["fp16", "bf16", "float"],
+                value="fp16",
+                interactive=True,
+                scale=1,
             )
             load_precision = gr.Radio(
-                label='Load precision',
-                choices=['fp16', 'bf16', 'float'],
-                value='fp16',
-                interactive=True, scale=1,
+                label="Load precision",
+                choices=["fp16", "bf16", "float"],
+                value="fp16",
+                interactive=True,
+                scale=1,
             )

         model_tuned.change(
@@ -256,7 +273,7 @@ def change_sdxl(sdxl):
             dim = gr.Slider(
                 minimum=4,
                 maximum=1024,
-                label='Network Dimension (Rank)',
+                label="Network Dimension (Rank)",
                 value=128,
                 step=1,
                 interactive=True,
@@ -264,13 +281,13 @@ def change_sdxl(sdxl):
             conv_dim = gr.Slider(
                 minimum=0,
                 maximum=1024,
-                label='Conv Dimension (Rank)',
+                label="Conv Dimension (Rank)",
                 value=128,
                 step=1,
                 interactive=True,
             )
             clamp_quantile = gr.Number(
-                label='Clamp Quantile',
+                label="Clamp Quantile",
                 value=0.99,
                 minimum=0,
                 maximum=1,
@@ -278,7 +295,7 @@ def change_sdxl(sdxl):
                 interactive=True,
             )
             min_diff = gr.Number(
-                label='Minimum difference',
+                label="Minimum difference",
                 value=0.01,
                 minimum=0,
                 maximum=1,
@@ -286,21 +303,25 @@ def change_sdxl(sdxl):
                 interactive=True,
             )
         with gr.Row():
-            v2 = gr.Checkbox(label='v2', value=False, interactive=True)
-            sdxl = gr.Checkbox(label='SDXL', value=False, interactive=True)
+            v2 = gr.Checkbox(label="v2", value=False, interactive=True)
+            sdxl = gr.Checkbox(label="SDXL", value=False, interactive=True)
             device = gr.Radio(
-                label='Device',
+                label="Device",
                 choices=[
-                    'cpu',
-                    'cuda',
+                    "cpu",
+                    "cuda",
                 ],
-                value='cuda',
+                value="cuda",
                 interactive=True,
             )
-
-        sdxl.change(change_sdxl, inputs=sdxl, outputs=[load_tuned_model_to, load_original_model_to])

-        extract_button = gr.Button('Extract LoRA model')
+        sdxl.change(
+            change_sdxl,
+            inputs=sdxl,
+            outputs=[load_tuned_model_to, load_original_model_to],
+        )
+
+        extract_button = gr.Button("Extract LoRA model")

         extract_button.click(
             extract_lora,
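The extract tools above assemble a single shell string with manually quoted paths and run it via subprocess.run(run_cmd, shell=True, env=env). That is not changed by this diff; purely as a hedged aside, the same extract_lora_from_models.py invocation could be expressed as an argument list, which sidesteps manual quoting of paths containing spaces (all paths below are placeholders):

import subprocess
import sys

scriptdir = "/app"                                      # assumed
model_org = "/models/sd15-base.safetensors"             # assumed
model_tuned = "/models/sd15-finetuned.safetensors"      # assumed
save_to = "/outputs/extracted_lora.safetensors"         # assumed

cmd = [
    sys.executable,
    f"{scriptdir}/sd-scripts/networks/extract_lora_from_models.py",
    "--load_precision", "fp16",
    "--save_precision", "fp16",
    "--save_to", save_to,
    "--model_org", model_org,
    "--model_tuned", model_tuned,
    "--dim", "128",
    "--device", "cuda",
]
subprocess.run(cmd)  # no shell, so no quoting pitfalls

The string form the GUI keeps has the advantage that the logged command can be copy-pasted into a terminal verbatim.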
rf"{scriptdir}{os.pathsep}{scriptdir}/sd-scripts{os.pathsep}{env.get('PYTHONPATH', '')}" + ) # Run the command subprocess.run(run_cmd, shell=True, env=env) - log.info('Done extracting...') + log.info("Done extracting...") ### @@ -185,11 +186,16 @@ def list_save_to(path): value="", allow_custom_value=True, ) - create_refresh_button(db_model, lambda: None, lambda: {"choices": list_models(current_model_dir)}, "open_folder_small") + create_refresh_button( + db_model, + lambda: None, + lambda: {"choices": list_models(current_model_dir)}, + "open_folder_small", + ) button_db_model_file = gr.Button( folder_symbol, elem_id="open_folder_small", - elem_classes=['tool'], + elem_classes=["tool"], visible=(not headless), ) button_db_model_file.click( @@ -205,11 +211,16 @@ def list_save_to(path): value="", allow_custom_value=True, ) - create_refresh_button(base_model, lambda: None, lambda: {"choices": list_base_models(current_base_model_dir)}, "open_folder_small") + create_refresh_button( + base_model, + lambda: None, + lambda: {"choices": list_base_models(current_base_model_dir)}, + "open_folder_small", + ) button_base_model_file = gr.Button( folder_symbol, elem_id="open_folder_small", - elem_classes=['tool'], + elem_classes=["tool"], visible=(not headless), ) button_base_model_file.click( @@ -227,11 +238,16 @@ def list_save_to(path): allow_custom_value=True, scale=2, ) - create_refresh_button(output_name, lambda: None, lambda: {"choices": list_save_to(current_save_dir)}, "open_folder_small") + create_refresh_button( + output_name, + lambda: None, + lambda: {"choices": list_save_to(current_save_dir)}, + "open_folder_small", + ) button_output_name = gr.Button( folder_symbol, elem_id="open_folder_small", - elem_classes=['tool'], + elem_classes=["tool"], visible=(not headless), ) button_output_name.click( @@ -270,7 +286,9 @@ def list_save_to(path): show_progress=False, ) - is_sdxl = gr.Checkbox(label="is SDXL", value=False, interactive=True, scale=1) + is_sdxl = gr.Checkbox( + label="is SDXL", value=False, interactive=True, scale=1 + ) is_v2 = gr.Checkbox(label="is v2", value=False, interactive=True, scale=1) with gr.Row(): diff --git a/kohya_gui/finetune_gui.py b/kohya_gui/finetune_gui.py index 9fc49e34e..59d48b0f8 100644 --- a/kohya_gui/finetune_gui.py +++ b/kohya_gui/finetune_gui.py @@ -4,12 +4,10 @@ import os import subprocess import sys -import pathlib from datetime import datetime from .common_gui import ( get_file_path, get_saveasfile_path, - save_inference_file, run_cmd_advanced_training, color_aug_changed, update_my_data, @@ -19,6 +17,7 @@ scriptdir, validate_paths, ) +from .class_accelerate_launch import AccelerateLaunch from .class_configuration_file import ConfigurationFile from .class_source_model import SourceModel from .class_basic_training import BasicTraining @@ -50,7 +49,8 @@ PYTHON = sys.executable -presets_dir = fr'{scriptdir}/presets' +presets_dir = rf"{scriptdir}/presets" + def save_configuration( save_as, @@ -69,6 +69,7 @@ def save_configuration( max_bucket_reso, batch_size, flip_aug, + masked_loss, caption_metadata_filename, latent_metadata_filename, full_path, @@ -99,7 +100,9 @@ def save_configuration( num_machines, multi_gpu, gpu_ids, + main_process_port, save_state, + save_state_on_train_end, resume, gradient_checkpointing, gradient_accumulation_steps, @@ -130,9 +133,12 @@ def save_configuration( lr_scheduler_args, noise_offset_type, noise_offset, + noise_offset_random_strength, adaptive_noise_scale, multires_noise_iterations, multires_noise_discount, + ip_noise_gamma, + 
ip_noise_gamma_random_strength, sample_every_n_steps, sample_every_n_epochs, sample_sampler, @@ -154,6 +160,7 @@ def save_configuration( sdxl_no_half_vae, min_timestep, max_timestep, + extra_accelerate_launch_args, ): # Get list of function parameters and values parameters = list(locals().items()) @@ -209,6 +216,7 @@ def open_configuration( max_bucket_reso, batch_size, flip_aug, + masked_loss, caption_metadata_filename, latent_metadata_filename, full_path, @@ -239,7 +247,9 @@ def open_configuration( num_machines, multi_gpu, gpu_ids, + main_process_port, save_state, + save_state_on_train_end, resume, gradient_checkpointing, gradient_accumulation_steps, @@ -270,9 +280,12 @@ def open_configuration( lr_scheduler_args, noise_offset_type, noise_offset, + noise_offset_random_strength, adaptive_noise_scale, multires_noise_iterations, multires_noise_discount, + ip_noise_gamma, + ip_noise_gamma_random_strength, sample_every_n_steps, sample_every_n_epochs, sample_sampler, @@ -294,6 +307,7 @@ def open_configuration( sdxl_no_half_vae, min_timestep, max_timestep, + extra_accelerate_launch_args, training_preset, ): # Get list of function parameters and values @@ -305,7 +319,7 @@ def open_configuration( # Check if we are "applying" a preset or a config if apply_preset: log.info(f"Applying preset {training_preset}...") - file_path = fr'{presets_dir}/finetune/{training_preset}.json' + file_path = rf"{presets_dir}/finetune/{training_preset}.json" else: # If not applying a preset, set the `training_preset` field to an empty string # Find the index of the `training_preset` parameter using the `index()` method @@ -356,6 +370,7 @@ def train_model( max_bucket_reso, batch_size, flip_aug, + masked_loss, caption_metadata_filename, latent_metadata_filename, full_path, @@ -386,7 +401,9 @@ def train_model( num_machines, multi_gpu, gpu_ids, + main_process_port, save_state, + save_state_on_train_end, resume, gradient_checkpointing, gradient_accumulation_steps, @@ -417,9 +434,12 @@ def train_model( lr_scheduler_args, noise_offset_type, noise_offset, + noise_offset_random_strength, adaptive_noise_scale, multires_noise_iterations, multires_noise_discount, + ip_noise_gamma, + ip_noise_gamma_random_strength, sample_every_n_steps, sample_every_n_epochs, sample_sampler, @@ -441,6 +461,7 @@ def train_model( sdxl_no_half_vae, min_timestep, max_timestep, + extra_accelerate_launch_args, ): # Get list of function parameters and values parameters = list(locals().items()) @@ -461,32 +482,38 @@ def train_model( logging_dir=logging_dir, log_tracker_config=log_tracker_config, resume=resume, - dataset_config=dataset_config + dataset_config=dataset_config, ): return - if not print_only_bool and check_if_model_exist(output_name, output_dir, save_model_as, headless_bool): + if not print_only_bool and check_if_model_exist( + output_name, output_dir, save_model_as, headless_bool + ): return if dataset_config: - log.info("Dataset config toml file used, skipping caption json file, image buckets, total_steps, train_batch_size, gradient_accumulation_steps, epoch, reg_factor, max_train_steps creation...") + log.info( + "Dataset config toml file used, skipping caption json file, image buckets, total_steps, train_batch_size, gradient_accumulation_steps, epoch, reg_factor, max_train_steps creation..." 
+ ) else: # create caption json file if generate_caption_database: - run_cmd = fr'"{PYTHON}" "{scriptdir}/sd-scripts/finetune/merge_captions_to_metadata.py"' + run_cmd = rf'"{PYTHON}" "{scriptdir}/sd-scripts/finetune/merge_captions_to_metadata.py"' if caption_extension == "": run_cmd += f' --caption_extension=".caption"' else: run_cmd += f" --caption_extension={caption_extension}" - run_cmd += fr' "{image_folder}"' - run_cmd += fr' "{train_dir}/{caption_metadata_filename}"' + run_cmd += rf' "{image_folder}"' + run_cmd += rf' "{train_dir}/{caption_metadata_filename}"' if full_path: run_cmd += f" --full_path" log.info(run_cmd) env = os.environ.copy() - env['PYTHONPATH'] = fr"{scriptdir}{os.pathsep}{scriptdir}/sd-scripts{os.pathsep}{env.get('PYTHONPATH', '')}" + env["PYTHONPATH"] = ( + rf"{scriptdir}{os.pathsep}{scriptdir}/sd-scripts{os.pathsep}{env.get('PYTHONPATH', '')}" + ) if not print_only_bool: # Run the command @@ -494,11 +521,11 @@ def train_model( # create images buckets if generate_image_buckets: - run_cmd = fr'"{PYTHON}" "{scriptdir}/sd-scripts/finetune/prepare_buckets_latents.py"' - run_cmd += fr' "{image_folder}"' - run_cmd += fr' "{train_dir}/{caption_metadata_filename}"' - run_cmd += fr' "{train_dir}/{latent_metadata_filename}"' - run_cmd += fr' "{pretrained_model_name_or_path}"' + run_cmd = rf'"{PYTHON}" "{scriptdir}/sd-scripts/finetune/prepare_buckets_latents.py"' + run_cmd += rf' "{image_folder}"' + run_cmd += rf' "{train_dir}/{caption_metadata_filename}"' + run_cmd += rf' "{train_dir}/{latent_metadata_filename}"' + run_cmd += rf' "{pretrained_model_name_or_path}"' run_cmd += f" --batch_size={batch_size}" run_cmd += f" --max_resolution={max_resolution}" run_cmd += f" --min_bucket_reso={min_bucket_reso}" @@ -509,13 +536,17 @@ def train_model( if full_path: run_cmd += f" --full_path" if sdxl_checkbox and sdxl_no_half_vae: - log.info("Using mixed_precision = no because no half vae is selected...") + log.info( + "Using mixed_precision = no because no half vae is selected..." 
+ ) run_cmd += f' --mixed_precision="no"' log.info(run_cmd) env = os.environ.copy() - env['PYTHONPATH'] = fr"{scriptdir}{os.pathsep}{scriptdir}/sd-scripts{os.pathsep}{env.get('PYTHONPATH', '')}" + env["PYTHONPATH"] = ( + rf"{scriptdir}{os.pathsep}{scriptdir}/sd-scripts{os.pathsep}{env.get('PYTHONPATH', '')}" + ) if not print_only_bool: # Run the command @@ -558,23 +589,26 @@ def train_model( run_cmd = "accelerate launch" - run_cmd += run_cmd_advanced_training( + run_cmd += AccelerateLaunch.run_cmd( num_processes=num_processes, num_machines=num_machines, multi_gpu=multi_gpu, gpu_ids=gpu_ids, + main_process_port=main_process_port, num_cpu_threads_per_process=num_cpu_threads_per_process, + mixed_precision=mixed_precision, + extra_accelerate_launch_args=extra_accelerate_launch_args, ) if sdxl_checkbox: - run_cmd += fr' "{scriptdir}/sd-scripts/sdxl_train.py"' + run_cmd += rf' "{scriptdir}/sd-scripts/sdxl_train.py"' else: - run_cmd += fr' "{scriptdir}/sd-scripts/fine_tune.py"' + run_cmd += rf' "{scriptdir}/sd-scripts/fine_tune.py"' in_json = ( - fr"{train_dir}/{latent_metadata_filename}" + rf"{train_dir}/{latent_metadata_filename}" if use_latent_files == "Yes" - else fr"{train_dir}/{caption_metadata_filename}" + else rf"{train_dir}/{caption_metadata_filename}" ) cache_text_encoder_outputs = sdxl_checkbox and sdxl_cache_text_encoder_outputs no_half_vae = sdxl_checkbox and sdxl_no_half_vae @@ -596,11 +630,14 @@ def train_model( "dataset_repeats": dataset_repeats, "enable_bucket": True, "flip_aug": flip_aug, + "masked_loss": masked_loss, "full_bf16": full_bf16, "full_fp16": full_fp16, "gradient_accumulation_steps": gradient_accumulation_steps, "gradient_checkpointing": gradient_checkpointing, "in_json": in_json, + "ip_noise_gamma": ip_noise_gamma, + "ip_noise_gamma_random_strength": ip_noise_gamma_random_strength, "keep_tokens": keep_tokens, "learning_rate": learning_rate, "logging_dir": logging_dir, @@ -624,6 +661,7 @@ def train_model( "multires_noise_discount": multires_noise_discount, "multires_noise_iterations": multires_noise_iterations, "noise_offset": noise_offset, + "noise_offset_random_strength": noise_offset_random_strength, "noise_offset_type": noise_offset_type, "optimizer": optimizer, "optimizer_args": optimizer_args, @@ -640,6 +678,7 @@ def train_model( "save_model_as": save_model_as, "save_precision": save_precision, "save_state": save_state, + "save_state_on_train_end": save_state_on_train_end, "scale_v_pred_loss_like_noise_pred": scale_v_pred_loss_like_noise_pred, "seed": seed, "shuffle_caption": shuffle_caption, @@ -703,18 +742,14 @@ def train_model( log.info(run_cmd) env = os.environ.copy() - env['PYTHONPATH'] = fr"{scriptdir}{os.pathsep}{scriptdir}/sd-scripts{os.pathsep}{env.get('PYTHONPATH', '')}" + env["PYTHONPATH"] = ( + rf"{scriptdir}{os.pathsep}{scriptdir}/sd-scripts{os.pathsep}{env.get('PYTHONPATH', '')}" + ) + env["TF_ENABLE_ONEDNN_OPTS"] = "0" # Run the command executor.execute_command(run_cmd=run_cmd, env=env) - # # check if output_dir/last is a folder... 
therefore it is a diffuser model - # last_dir = pathlib.Path(f"{output_dir}/{output_name}") - - # if not last_dir.is_dir(): - # # Copy inference model for v2 if required - # save_inference_file(output_dir, v2, v_parameterization, output_name) - def finetune_tab(headless=False, config: dict = {}): dummy_db_true = gr.Label(value=True, visible=False) @@ -723,8 +758,13 @@ def finetune_tab(headless=False, config: dict = {}): with gr.Tab("Training"), gr.Column(variant="compact"): gr.Markdown("Train a custom model using kohya finetune python code...") + with gr.Accordion("Accelerate launch", open=False), gr.Column(): + accelerate_launch = AccelerateLaunch(config=config) + with gr.Column(): - source_model = SourceModel(headless=headless, finetuning=True, config=config) + source_model = SourceModel( + headless=headless, finetuning=True, config=config + ) image_folder = source_model.train_data_dir output_name = source_model.output_name @@ -758,33 +798,38 @@ def list_presets(path): elem_id="myDropdown", ) - with gr.Group(elem_id="basic_tab"): - basic_training = BasicTraining( - learning_rate_value="1e-5", - finetuning=True, - sdxl_checkbox=source_model.sdxl_checkbox, - ) + with gr.Accordion("Basic", open="True"): + with gr.Group(elem_id="basic_tab"): + basic_training = BasicTraining( + learning_rate_value="1e-5", + finetuning=True, + sdxl_checkbox=source_model.sdxl_checkbox, + config=config, + ) - # Add SDXL Parameters - sdxl_params = SDXLParameters(source_model.sdxl_checkbox) + # Add SDXL Parameters + sdxl_params = SDXLParameters(source_model.sdxl_checkbox, config=config) - with gr.Row(): - dataset_repeats = gr.Textbox(label="Dataset repeats", value=40) - train_text_encoder = gr.Checkbox( - label="Train text encoder", value=True - ) + with gr.Row(): + dataset_repeats = gr.Textbox(label="Dataset repeats", value=40) + train_text_encoder = gr.Checkbox( + label="Train text encoder", value=True + ) with gr.Accordion("Advanced", open=False, elem_id="advanced_tab"): with gr.Row(): gradient_accumulation_steps = gr.Number( - label="Gradient accumulate steps", value="1", + label="Gradient accumulate steps", + value="1", ) block_lr = gr.Textbox( label="Block LR (SDXL)", placeholder="(Optional)", info="Specify the different learning rates for each U-Net block. Specify 23 values separated by commas like 1e-3,1e-3 ... 
1e-3", ) - advanced_training = AdvancedTraining(headless=headless, finetuning=True, config=config) + advanced_training = AdvancedTraining( + headless=headless, finetuning=True, config=config + ) advanced_training.color_aug.change( color_aug_changed, inputs=[advanced_training.color_aug], @@ -794,7 +839,7 @@ def list_presets(path): ) with gr.Accordion("Samples", open=False, elem_id="samples_tab"): - sample = SampleImages() + sample = SampleImages(config=config) with gr.Accordion("Dataset Preparation", open=False): with gr.Row(): @@ -840,7 +885,6 @@ def list_presets(path): with gr.Accordion("Configuration", open=False): configuration = ConfigurationFile(headless=headless, config=config) - with gr.Column(), gr.Group(): with gr.Row(): button_run = gr.Button("Start training", variant="primary") @@ -881,6 +925,7 @@ def list_presets(path): max_bucket_reso, batch_size, advanced_training.flip_aug, + advanced_training.masked_loss, caption_metadata_filename, latent_metadata_filename, full_path, @@ -891,10 +936,10 @@ def list_presets(path): basic_training.train_batch_size, basic_training.epoch, basic_training.save_every_n_epochs, - basic_training.mixed_precision, + accelerate_launch.mixed_precision, source_model.save_precision, basic_training.seed, - basic_training.num_cpu_threads_per_process, + accelerate_launch.num_cpu_threads_per_process, basic_training.learning_rate_te, basic_training.learning_rate_te1, basic_training.learning_rate_te2, @@ -906,11 +951,13 @@ def list_presets(path): basic_training.caption_extension, advanced_training.xformers, advanced_training.clip_skip, - advanced_training.num_processes, - advanced_training.num_machines, - advanced_training.multi_gpu, - advanced_training.gpu_ids, + accelerate_launch.num_processes, + accelerate_launch.num_machines, + accelerate_launch.multi_gpu, + accelerate_launch.gpu_ids, + accelerate_launch.main_process_port, advanced_training.save_state, + advanced_training.save_state_on_train_end, advanced_training.resume, advanced_training.gradient_checkpointing, gradient_accumulation_steps, @@ -941,9 +988,12 @@ def list_presets(path): basic_training.lr_scheduler_args, advanced_training.noise_offset_type, advanced_training.noise_offset, + advanced_training.noise_offset_random_strength, advanced_training.adaptive_noise_scale, advanced_training.multires_noise_iterations, advanced_training.multires_noise_discount, + advanced_training.ip_noise_gamma, + advanced_training.ip_noise_gamma_random_strength, sample.sample_every_n_steps, sample.sample_every_n_epochs, sample.sample_sampler, @@ -965,6 +1015,7 @@ def list_presets(path): sdxl_params.sdxl_no_half_vae, advanced_training.min_timestep, advanced_training.max_timestep, + accelerate_launch.extra_accelerate_launch_args, ] configuration.button_open_config.click( @@ -972,7 +1023,9 @@ def list_presets(path): inputs=[dummy_db_true, dummy_db_false, configuration.config_file_name] + settings_list + [training_preset], - outputs=[configuration.config_file_name] + settings_list + [training_preset], + outputs=[configuration.config_file_name] + + settings_list + + [training_preset], show_progress=False, ) @@ -988,7 +1041,9 @@ def list_presets(path): inputs=[dummy_db_false, dummy_db_false, configuration.config_file_name] + settings_list + [training_preset], - outputs=[configuration.config_file_name] + settings_list + [training_preset], + outputs=[configuration.config_file_name] + + settings_list + + [training_preset], show_progress=False, ) @@ -1029,16 +1084,16 @@ def list_presets(path): show_progress=False, ) - 
#config.button_save_as_config.click( + # config.button_save_as_config.click( # save_configuration, # inputs=[dummy_db_true, config.config_file_name] + settings_list, # outputs=[config.config_file_name], # show_progress=False, - #) + # ) with gr.Tab("Guides"): gr.Markdown("This section provide Various Finetuning guides and information...") - top_level_path = fr"{scriptdir}/docs/Finetuning/top_level.md" + top_level_path = rf"{scriptdir}/docs/Finetuning/top_level.md" if os.path.exists(top_level_path): with open(os.path.join(top_level_path), "r", encoding="utf8") as file: guides_top_level = file.read() + "\n" diff --git a/kohya_gui/git_caption_gui.py b/kohya_gui/git_caption_gui.py index 171b8c78b..a98449749 100644 --- a/kohya_gui/git_caption_gui.py +++ b/kohya_gui/git_caption_gui.py @@ -24,31 +24,31 @@ def caption_images( postfix, ): # Check for images_dir_input - if train_data_dir == '': - msgbox('Image folder is missing...') + if train_data_dir == "": + msgbox("Image folder is missing...") return - if caption_ext == '': - msgbox('Please provide an extension for the caption files.') + if caption_ext == "": + msgbox("Please provide an extension for the caption files.") return - log.info(f'GIT captioning files in {train_data_dir}...') - run_cmd = fr'"{PYTHON}" "{scriptdir}/sd-scripts/finetune/make_captions_by_git.py"' - if not model_id == '': + log.info(f"GIT captioning files in {train_data_dir}...") + run_cmd = rf'"{PYTHON}" "{scriptdir}/sd-scripts/finetune/make_captions_by_git.py"' + if not model_id == "": run_cmd += f' --model_id="{model_id}"' run_cmd += f' --batch_size="{int(batch_size)}"' - run_cmd += ( - f' --max_data_loader_n_workers="{int(max_data_loader_n_workers)}"' - ) + run_cmd += f' --max_data_loader_n_workers="{int(max_data_loader_n_workers)}"' run_cmd += f' --max_length="{int(max_length)}"' - if caption_ext != '': + if caption_ext != "": run_cmd += f' --caption_extension="{caption_ext}"' run_cmd += f' "{train_data_dir}"' log.info(run_cmd) env = os.environ.copy() - env['PYTHONPATH'] = fr"{scriptdir}{os.pathsep}{scriptdir}/sd-scripts{os.pathsep}{env.get('PYTHONPATH', '')}" + env["PYTHONPATH"] = ( + rf"{scriptdir}{os.pathsep}{scriptdir}/sd-scripts{os.pathsep}{env.get('PYTHONPATH', '')}" + ) # Run the command subprocess.run(run_cmd, shell=True, env=env) @@ -61,7 +61,7 @@ def caption_images( postfix=postfix, ) - log.info('...captioning done') + log.info("...captioning done") ### @@ -72,7 +72,11 @@ def caption_images( def gradio_git_caption_gui_tab(headless=False, default_train_dir=None): from .common_gui import create_refresh_button - default_train_dir = default_train_dir if default_train_dir is not None else os.path.join(scriptdir, "data") + default_train_dir = ( + default_train_dir + if default_train_dir is not None + else os.path.join(scriptdir, "data") + ) current_train_dir = default_train_dir def list_train_dirs(path): @@ -80,21 +84,29 @@ def list_train_dirs(path): current_train_dir = path return list(list_dirs(path)) - with gr.Tab('GIT Captioning'): + with gr.Tab("GIT Captioning"): gr.Markdown( - 'This utility will use GIT to caption files for each images in a folder.' + "This utility will use GIT to caption files for each images in a folder." 
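For reference, caption_images() in the next file assembles its GIT captioning call as one quoted shell string. With the GUI defaults shown below (batch size 1, 2 workers, max length 75, .txt captions, no custom model_id), the logged command comes out roughly like this; the interpreter and folder paths are illustrative only:

run_cmd = (
    '"/usr/bin/python3" "/app/sd-scripts/finetune/make_captions_by_git.py"'
    ' --batch_size="1"'
    ' --max_data_loader_n_workers="2"'
    ' --max_length="75"'
    ' --caption_extension=".txt"'
    ' "/app/data/images"'
)

Because the string is executed with shell=True, the embedded double quotes are what protect paths containing spaces.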
diff --git a/kohya_gui/git_caption_gui.py b/kohya_gui/git_caption_gui.py
index 171b8c78b..a98449749 100644
--- a/kohya_gui/git_caption_gui.py
+++ b/kohya_gui/git_caption_gui.py
@@ -24,31 +24,31 @@ def caption_images(
     postfix,
 ):
     # Check for images_dir_input
-    if train_data_dir == '':
-        msgbox('Image folder is missing...')
+    if train_data_dir == "":
+        msgbox("Image folder is missing...")
         return

-    if caption_ext == '':
-        msgbox('Please provide an extension for the caption files.')
+    if caption_ext == "":
+        msgbox("Please provide an extension for the caption files.")
         return

-    log.info(f'GIT captioning files in {train_data_dir}...')
-    run_cmd = fr'"{PYTHON}" "{scriptdir}/sd-scripts/finetune/make_captions_by_git.py"'
-    if not model_id == '':
+    log.info(f"GIT captioning files in {train_data_dir}...")
+    run_cmd = rf'"{PYTHON}" "{scriptdir}/sd-scripts/finetune/make_captions_by_git.py"'
+    if not model_id == "":
         run_cmd += f' --model_id="{model_id}"'
     run_cmd += f' --batch_size="{int(batch_size)}"'
-    run_cmd += (
-        f' --max_data_loader_n_workers="{int(max_data_loader_n_workers)}"'
-    )
+    run_cmd += f' --max_data_loader_n_workers="{int(max_data_loader_n_workers)}"'
     run_cmd += f' --max_length="{int(max_length)}"'
-    if caption_ext != '':
+    if caption_ext != "":
         run_cmd += f' --caption_extension="{caption_ext}"'
     run_cmd += f' "{train_data_dir}"'

     log.info(run_cmd)

     env = os.environ.copy()
-    env['PYTHONPATH'] = fr"{scriptdir}{os.pathsep}{scriptdir}/sd-scripts{os.pathsep}{env.get('PYTHONPATH', '')}"
+    env["PYTHONPATH"] = (
+        rf"{scriptdir}{os.pathsep}{scriptdir}/sd-scripts{os.pathsep}{env.get('PYTHONPATH', '')}"
+    )

     # Run the command
     subprocess.run(run_cmd, shell=True, env=env)

@@ -61,7 +61,7 @@ def caption_images(
         postfix=postfix,
     )

-    log.info('...captioning done')
+    log.info("...captioning done")


 ###

@@ -72,7 +72,11 @@ def caption_images(
 def gradio_git_caption_gui_tab(headless=False, default_train_dir=None):
     from .common_gui import create_refresh_button

-    default_train_dir = default_train_dir if default_train_dir is not None else os.path.join(scriptdir, "data")
+    default_train_dir = (
+        default_train_dir
+        if default_train_dir is not None
+        else os.path.join(scriptdir, "data")
+    )
     current_train_dir = default_train_dir

     def list_train_dirs(path):
@@ -80,21 +84,29 @@ def list_train_dirs(path):
         current_train_dir = path
         return list(list_dirs(path))

-    with gr.Tab('GIT Captioning'):
+    with gr.Tab("GIT Captioning"):
         gr.Markdown(
-            'This utility will use GIT to caption files for each images in a folder.'
+            "This utility will use GIT to caption files for each image in a folder."
         )
         with gr.Group(), gr.Row():
             train_data_dir = gr.Dropdown(
-                label='Image folder to caption (containing the images to caption)',
+                label="Image folder to caption (containing the images to caption)",
                 choices=[""] + list_train_dirs(default_train_dir),
                 value="",
                 interactive=True,
                 allow_custom_value=True,
             )
-            create_refresh_button(train_data_dir, lambda: None, lambda: {"choices": list_train_dirs(current_train_dir)},"open_folder_small")
+            create_refresh_button(
+                train_data_dir,
+                lambda: None,
+                lambda: {"choices": list_train_dirs(current_train_dir)},
+                "open_folder_small",
+            )
             button_train_data_dir_input = gr.Button(
-                '📂', elem_id='open_folder_small', elem_classes=['tool'], visible=(not headless)
+                "📂",
+                elem_id="open_folder_small",
+                elem_classes=["tool"],
+                visible=(not headless),
             )
             button_train_data_dir_input.click(
                 get_folder_path,
@@ -103,42 +115,38 @@ def list_train_dirs(path):
             )
         with gr.Row():
             caption_ext = gr.Textbox(
-                label='Caption file extension',
-                placeholder='Extension for caption file (e.g., .caption, .txt)',
-                value='.txt',
+                label="Caption file extension",
+                placeholder="Extension for caption file (e.g., .caption, .txt)",
+                value=".txt",
                 interactive=True,
             )

             prefix = gr.Textbox(
-                label='Prefix to add to GIT caption',
-                placeholder='(Optional)',
+                label="Prefix to add to GIT caption",
+                placeholder="(Optional)",
                 interactive=True,
             )

             postfix = gr.Textbox(
-                label='Postfix to add to GIT caption',
-                placeholder='(Optional)',
+                label="Postfix to add to GIT caption",
+                placeholder="(Optional)",
                 interactive=True,
             )

-            batch_size = gr.Number(
-                value=1, label='Batch size', interactive=True
-            )
+            batch_size = gr.Number(value=1, label="Batch size", interactive=True)

         with gr.Row():
             max_data_loader_n_workers = gr.Number(
-                value=2, label='Number of workers', interactive=True
-            )
-            max_length = gr.Number(
-                value=75, label='Max length', interactive=True
+                value=2, label="Number of workers", interactive=True
             )
+            max_length = gr.Number(value=75, label="Max length", interactive=True)

             model_id = gr.Textbox(
-                label='Model',
-                placeholder='(Optional) model id for GIT in Hugging Face',
+                label="Model",
+                placeholder="(Optional) model id for GIT in Hugging Face",
                 interactive=True,
             )

-        caption_button = gr.Button('Caption images')
+        caption_button = gr.Button("Caption images")

         caption_button.click(
             caption_images,
diff --git a/kohya_gui/group_images_gui.py b/kohya_gui/group_images_gui.py
index 86854fb40..bbd78d852 100644
--- a/kohya_gui/group_images_gui.py
+++ b/kohya_gui/group_images_gui.py
@@ -22,38 +22,40 @@ def group_images(
     generate_captions,
     caption_ext,
 ):
-    if input_folder == '':
-        msgbox('Input folder is missing...')
+    if input_folder == "":
+        msgbox("Input folder is missing...")
         return

-    if output_folder == '':
-        msgbox('Please provide an output folder.')
+    if output_folder == "":
+        msgbox("Please provide an output folder.")
         return

-    log.info(f'Grouping images in {input_folder}...')
+    log.info(f"Grouping images in {input_folder}...")

-    run_cmd = fr'"{PYTHON}" "{scriptdir}/tools/group_images.py"'
+    run_cmd = rf'"{PYTHON}" "{scriptdir}/tools/group_images.py"'
     run_cmd += f' "{input_folder}"'
     run_cmd += f' "{output_folder}"'
-    run_cmd += f' {(group_size)}'
+    run_cmd += f" {(group_size)}"
     if include_subfolders:
-        run_cmd += f' --include_subfolders'
+        run_cmd += f" --include_subfolders"
     if do_not_copy_other_files:
-        run_cmd += f' --do_not_copy_other_files'
+        run_cmd += f" --do_not_copy_other_files"
     if generate_captions:
-        run_cmd += f' --caption'
+        run_cmd += f" --caption"
     if caption_ext:
-        run_cmd += f' --caption_ext={caption_ext}'
+        run_cmd += f" --caption_ext={caption_ext}"

     log.info(run_cmd)

     env = os.environ.copy()
-    env['PYTHONPATH'] = fr"{scriptdir}{os.pathsep}{scriptdir}/sd-scripts{os.pathsep}{env.get('PYTHONPATH', '')}"
+    env["PYTHONPATH"] = (
+        rf"{scriptdir}{os.pathsep}{scriptdir}/sd-scripts{os.pathsep}{env.get('PYTHONPATH', '')}"
+    )

     # Run the command
     subprocess.run(run_cmd, shell=True, env=env)

-    log.info('...grouping done')
+    log.info("...grouping done")


 def gradio_group_images_gui_tab(headless=False):
@@ -72,22 +74,30 @@ def list_output_dirs(path):
         current_output_folder = path
         return list(list_dirs(path))

-    with gr.Tab('Group Images'):
+    with gr.Tab("Group Images"):
         gr.Markdown(
-            'This utility will group images in a folder based on their aspect ratio.'
+            "This utility will group images in a folder based on their aspect ratio."
         )

         with gr.Group(), gr.Row():
             input_folder = gr.Dropdown(
-                label='Input folder (containing the images to group)',
+                label="Input folder (containing the images to group)",
                 interactive=True,
                 choices=[""] + list_input_dirs(current_input_folder),
                 value="",
                 allow_custom_value=True,
             )
-            create_refresh_button(input_folder, lambda: None, lambda: {"choices": list_input_dirs(current_input_folder)},"open_folder_small")
+            create_refresh_button(
+                input_folder,
+                lambda: None,
+                lambda: {"choices": list_input_dirs(current_input_folder)},
+                "open_folder_small",
+            )
             button_input_folder = gr.Button(
-                '📂', elem_id='open_folder_small', elem_classes=['tool'], visible=(not headless)
+                "📂",
+                elem_id="open_folder_small",
+                elem_classes=["tool"],
+                visible=(not headless),
             )
             button_input_folder.click(
                 get_folder_path,
@@ -96,15 +106,23 @@ def list_output_dirs(path):
             )

             output_folder = gr.Dropdown(
-                label='Output folder (where the grouped images will be stored)',
+                label="Output folder (where the grouped images will be stored)",
                 interactive=True,
                 choices=[""] + list_output_dirs(current_output_folder),
                 value="",
                 allow_custom_value=True,
             )
-            create_refresh_button(output_folder, lambda: None, lambda: {"choices": list_output_dirs(current_output_folder)},"open_folder_small")
+            create_refresh_button(
+                output_folder,
+                lambda: None,
+                lambda: {"choices": list_output_dirs(current_output_folder)},
+                "open_folder_small",
+            )
             button_output_folder = gr.Button(
-                '📂', elem_id='open_folder_small', elem_classes=['tool'], visible=(not headless)
+                "📂",
+                elem_id="open_folder_small",
+                elem_classes=["tool"],
+                visible=(not headless),
             )
             button_output_folder.click(
                 get_folder_path,
@@ -126,9 +144,9 @@ def list_output_dirs(path):
         with gr.Row():
             group_size = gr.Slider(
-                label='Group size',
-                info='Number of images to group together',
-                value='4',
+                label="Group size",
+                info="Number of images to group together",
+                value="4",
                 minimum=1,
                 maximum=64,
                 step=1,
@@ -136,31 +154,31 @@ def list_output_dirs(path):
             )

             include_subfolders = gr.Checkbox(
-                label='Include Subfolders',
+                label="Include Subfolders",
                 value=False,
-                info='Include images in subfolders as well',
+                info="Include images in subfolders as well",
             )

             do_not_copy_other_files = gr.Checkbox(
-                label='Do not copy other files',
+                label="Do not copy other files",
                 value=False,
-                info='Do not copy other files in the input folder to the output folder',
+                info="Do not copy other files in the input folder to the output folder",
             )

             generate_captions = gr.Checkbox(
-                label='Generate Captions',
+                label="Generate Captions",
                 value=False,
-                info='Generate caption files for the grouped images based on their folder name',
+                info="Generate caption files for the grouped images based on their folder name",
             )

             caption_ext = gr.Textbox(
-                label='Caption Extension',
-                placeholder='Caption file extension (e.g., .txt)',
-                value='.txt',
+                label="Caption Extension",
+                placeholder="Caption file extension (e.g., .txt)",
+                value=".txt",
                 interactive=True,
             )

-        group_images_button = gr.Button('Group images')
+        group_images_button = gr.Button("Group images")

         group_images_button.click(
             group_images,
diff --git a/kohya_gui/localization.py b/kohya_gui/localization.py
index 4bed769bd..3cddec740 100644
--- a/kohya_gui/localization.py
+++ b/kohya_gui/localization.py
@@ -7,7 +7,7 @@ def load_localizations():
     localizationMap.clear()

-    dirname = './localizations'
+    dirname = "./localizations"
     for file in os.listdir(dirname):
         fn, ext = os.path.splitext(file)
         if ext.lower() != ".json":
@@ -28,4 +28,4 @@ def load_language_js(language_name: str) -> str:
     return f"window.localization = {json.dumps(data)}"


-load_localizations()
\ No newline at end of file
+load_localizations()
diff --git a/kohya_gui/localization_ext.py b/kohya_gui/localization_ext.py
index 0f7c64653..782e26749 100644
--- a/kohya_gui/localization_ext.py
+++ b/kohya_gui/localization_ext.py
@@ -4,13 +4,15 @@


 def file_path(fn):
-    return f'file={os.path.abspath(fn)}?{os.path.getmtime(fn)}'
+    return f"file={os.path.abspath(fn)}?{os.path.getmtime(fn)}"


 def js_html_str(language):
     head = f'\n'
-    head += f'\n'
-    head += f'\n'
+    head += (
+        f'\n'
+    )
+    head += f'\n'
     return head


@@ -22,12 +24,12 @@ def add_javascript(language):

     def template_response(*args, **kwargs):
         res = localization.GrRoutesTemplateResponse(*args, **kwargs)
-        res.body = res.body.replace(b'', f'{jsStr}'.encode("utf8"))
+        res.body = res.body.replace(b"", f"{jsStr}".encode("utf8"))
         res.init_headers()
         return res

     gr.routes.templates.TemplateResponse = template_response


-if not hasattr(localization, 'GrRoutesTemplateResponse'):
+if not hasattr(localization, "GrRoutesTemplateResponse"):
     localization.GrRoutesTemplateResponse = gr.routes.templates.TemplateResponse
     dim_from_weights,
     color_aug,
     flip_aug,
+    masked_loss,
     clip_skip,
     num_processes,
     num_machines,
     multi_gpu,
     gpu_ids,
+    main_process_port,
     gradient_accumulation_steps,
     mem_eff_attn,
     output_name,
@@ -341,9 +352,12 @@ def open_configuration(
     max_grad_norm,
     noise_offset_type,
     noise_offset,
+    noise_offset_random_strength,
     adaptive_noise_scale,
     multires_noise_iterations,
     multires_noise_discount,
+    ip_noise_gamma,
+    ip_noise_gamma_random_strength,
     LoRA_type,
     factor,
     bypass_mode,
@@ -397,6 +411,7 @@ def open_configuration(
     vae,
     LyCORIS_preset,
     debiased_estimation_loss,
+    extra_accelerate_launch_args,
     training_preset,
 ):
     # Get list of function parameters and values
@@ -510,6 +525,7 @@ def train_model(
     save_model_as,
     shuffle_caption,
     save_state,
+    save_state_on_train_end,
     resume,
     prior_loss_weight,
     text_encoder_lr,
@@ -519,11 +535,13 @@ def train_model(
     dim_from_weights,
     color_aug,
     flip_aug,
+    masked_loss,
     clip_skip,
     num_processes,
     num_machines,
     multi_gpu,
     gpu_ids,
+    main_process_port,
     gradient_accumulation_steps,
     mem_eff_attn,
     output_name,
@@ -550,9 +568,12 @@ def train_model(
     max_grad_norm,
     noise_offset_type,
     noise_offset,
+    noise_offset_random_strength,
     adaptive_noise_scale,
     multires_noise_iterations,
     multires_noise_discount,
+    ip_noise_gamma,
+    ip_noise_gamma_random_strength,
     LoRA_type,
     factor,
     bypass_mode,
@@ -606,6 +627,7 @@ def train_model(
     vae,
     LyCORIS_preset,
     debiased_estimation_loss,
+    extra_accelerate_launch_args,
 ):
     # Get list of function parameters and values
     parameters = list(locals().items())
@@ -765,12 +787,15 @@ def train_model(
 
     run_cmd = "accelerate launch"
 
-    run_cmd += run_cmd_advanced_training(
+    run_cmd += AccelerateLaunch.run_cmd(
         num_processes=num_processes,
         num_machines=num_machines,
         multi_gpu=multi_gpu,
         gpu_ids=gpu_ids,
+        main_process_port=main_process_port,
         num_cpu_threads_per_process=num_cpu_threads_per_process,
+        mixed_precision=mixed_precision,
+        extra_accelerate_launch_args=extra_accelerate_launch_args,
     )
 
     if sdxl:
@@ -880,22 +905,21 @@ def train_model(
     )
     # Convert learning rates to float once and store the result for re-use
     learning_rate = float(learning_rate) if learning_rate is not None else 0.0
-    text_encoder_lr_float = float(text_encoder_lr) if text_encoder_lr is not None else 0.0
+    text_encoder_lr_float = (
+        float(text_encoder_lr) if text_encoder_lr is not None else 0.0
+    )
     unet_lr_float = float(unet_lr) if unet_lr is not None else 0.0
 
     # Determine the training configuration based on learning rate values
     # Sets flags for training specific components based on the provided learning rates.
     if float(learning_rate) == unet_lr_float == text_encoder_lr_float == 0:
-        output_message(
-            msg="Please input learning rate values.", headless=headless_bool
-        )
+        output_message(msg="Please input learning rate values.", headless=headless_bool)
         return
     # Flag to train text encoder only if its learning rate is non-zero and unet's is zero.
     network_train_text_encoder_only = text_encoder_lr_float != 0 and unet_lr_float == 0
     # Flag to train unet only if its learning rate is non-zero and text encoder's is zero.
     network_train_unet_only = text_encoder_lr_float == 0 and unet_lr_float != 0
-
     # Define a dictionary of parameters
     run_cmd_params = {
         "adaptive_noise_scale": adaptive_noise_scale,
@@ -917,11 +941,14 @@ def train_model(
         "enable_bucket": enable_bucket,
         "epoch": epoch,
         "flip_aug": flip_aug,
+        "masked_loss": masked_loss,
         "fp8_base": fp8_base,
         "full_bf16": full_bf16,
         "full_fp16": full_fp16,
         "gradient_accumulation_steps": gradient_accumulation_steps,
         "gradient_checkpointing": gradient_checkpointing,
+        "ip_noise_gamma": ip_noise_gamma,
+        "ip_noise_gamma_random_strength": ip_noise_gamma_random_strength,
         "keep_tokens": keep_tokens,
         "learning_rate": learning_rate,
         "logging_dir": logging_dir,
@@ -957,6 +984,7 @@ def train_model(
         "network_train_text_encoder_only": network_train_text_encoder_only,
         "no_half_vae": True if sdxl and sdxl_no_half_vae else None,
         "noise_offset": noise_offset,
+        "noise_offset_random_strength": noise_offset_random_strength,
         "noise_offset_type": noise_offset_type,
         "optimizer": optimizer,
         "optimizer_args": optimizer_args,
@@ -975,6 +1003,7 @@ def train_model(
         "save_model_as": save_model_as,
         "save_precision": save_precision,
         "save_state": save_state,
+        "save_state_on_train_end": save_state_on_train_end,
         "scale_v_pred_loss_like_noise_pred": scale_v_pred_loss_like_noise_pred,
         "scale_weight_norms": scale_weight_norms,
         "seed": seed,
@@ -1036,6 +1065,7 @@ def train_model(
     env["PYTHONPATH"] = (
         rf"{scriptdir}{os.pathsep}{scriptdir}/sd-scripts{os.pathsep}{env.get('PYTHONPATH', '')}"
     )
+    env["TF_ENABLE_ONEDNN_OPTS"] = "0"
 
     # Run the command
     executor.execute_command(run_cmd=run_cmd, env=env)
@@ -1057,6 +1087,10 @@ def lora_tab(
     gr.Markdown(
         "Train a custom model using kohya train network LoRA python code..."
     )
+
+    with gr.Accordion("Accelerate launch", open=False), gr.Column():
+        accelerate_launch = AccelerateLaunch(config=config)
+
     with gr.Column():
         source_model = SourceModel(
             save_model_as_choices=[
@@ -1091,673 +1125,674 @@ def list_presets(path):
 
         return json_files
 
-    training_preset = gr.Dropdown(
-        label="Presets",
-        choices=[""] + list_presets(rf"{presets_dir}/lora"),
-        elem_id="myDropdown",
-        value="none",
-    )
+    with gr.Accordion("Basic", open="True"):
+        training_preset = gr.Dropdown(
+            label="Presets",
+            choices=[""] + list_presets(rf"{presets_dir}/lora"),
+            elem_id="myDropdown",
+            value="none",
+        )
 
-    with gr.Group(elem_id="basic_tab"):
-        with gr.Row():
-            LoRA_type = gr.Dropdown(
-                label="LoRA type",
-                choices=[
-                    "Kohya DyLoRA",
-                    "Kohya LoCon",
-                    "LoRA-FA",
-                    "LyCORIS/iA3",
-                    "LyCORIS/BOFT",
-                    "LyCORIS/Diag-OFT",
-                    "LyCORIS/DyLoRA",
-                    "LyCORIS/GLoRA",
-                    "LyCORIS/LoCon",
-                    "LyCORIS/LoHa",
-                    "LyCORIS/LoKr",
-                    "LyCORIS/Native Fine-Tuning",
-                    "Standard",
-                ],
-                value="Standard",
-            )
-            LyCORIS_preset = gr.Dropdown(
-                label="LyCORIS Preset",
-                choices=[
-                    "attn-mlp",
-                    "attn-only",
-                    "full",
-                    "full-lin",
-                    "unet-transformer-only",
-                    "unet-convblock-only",
-                ],
-                value="full",
-                visible=False,
-                interactive=True,
-                # info="https://github.com/KohakuBlueleaf/LyCORIS/blob/0006e2ffa05a48d8818112d9f70da74c0cd30b99/docs/Preset.md"
+        with gr.Group(elem_id="basic_tab"):
+            with gr.Row():
+                LoRA_type = gr.Dropdown(
+                    label="LoRA type",
+                    choices=[
+                        "Kohya DyLoRA",
+                        "Kohya LoCon",
+                        "LoRA-FA",
+                        "LyCORIS/iA3",
+                        "LyCORIS/BOFT",
+                        "LyCORIS/Diag-OFT",
+                        "LyCORIS/DyLoRA",
+                        "LyCORIS/GLoRA",
+                        "LyCORIS/LoCon",
+                        "LyCORIS/LoHa",
+                        "LyCORIS/LoKr",
+                        "LyCORIS/Native Fine-Tuning",
+                        "Standard",
+                    ],
+                    value="Standard",
+                )
+                LyCORIS_preset = gr.Dropdown(
+                    label="LyCORIS Preset",
+                    choices=[
+                        "attn-mlp",
+                        "attn-only",
+                        "full",
+                        "full-lin",
+                        "unet-transformer-only",
+                        "unet-convblock-only",
+                    ],
+                    value="full",
+                    visible=False,
+                    interactive=True,
+                    # info="https://github.com/KohakuBlueleaf/LyCORIS/blob/0006e2ffa05a48d8818112d9f70da74c0cd30b99/docs/Preset.md"
+                )
+            with gr.Group():
+                with gr.Row():
+                    lora_network_weights = gr.Textbox(
+                        label="LoRA network weights",
+                        placeholder="(Optional)",
+                        info="Path to an existing LoRA network weights to resume training from",
+                    )
+                    lora_network_weights_file = gr.Button(
+                        document_symbol,
+                        elem_id="open_folder_small",
+                        elem_classes=["tool"],
+                        visible=(not headless),
+                    )
+                    lora_network_weights_file.click(
+                        get_any_file_path,
+                        inputs=[lora_network_weights],
+                        outputs=lora_network_weights,
+                        show_progress=False,
+                    )
+                    dim_from_weights = gr.Checkbox(
+                        label="DIM from weights",
+                        value=False,
+                        info="Automatically determine the dim(rank) from the weight file.",
+                    )
+            basic_training = BasicTraining(
+                learning_rate_value="0.0001",
+                lr_scheduler_value="cosine",
+                lr_warmup_value="10",
+                sdxl_checkbox=source_model.sdxl_checkbox,
+                config=config,
             )
-        with gr.Group():
+
+            with gr.Row():
+                text_encoder_lr = gr.Number(
+                    label="Text Encoder learning rate",
+                    value="0.0001",
+                    info="(Optional)",
+                    minimum=0,
+                    maximum=1,
+                )
+
+                unet_lr = gr.Number(
+                    label="Unet learning rate",
+                    value="0.0001",
+                    info="(Optional)",
+                    minimum=0,
+                    maximum=1,
+                )
+
+            # Add SDXL Parameters
+            sdxl_params = SDXLParameters(source_model.sdxl_checkbox, config=config)
+
+            # LyCORIS Specific parameters
+            with gr.Accordion("LyCORIS", visible=False) as lycoris_accordion:
+                with gr.Row():
+                    factor = gr.Slider(
+                        label="LoKr factor",
+                        value=-1,
+                        minimum=-1,
+                        maximum=64,
+                        step=1,
+                        visible=False,
+                    )
+                    bypass_mode = gr.Checkbox(
+                        value=False,
+                        label="Bypass mode",
+                        info="Designed for bnb 8bit/4bit linear layer. (QLyCORIS)",
+                        visible=False,
+                    )
+                    dora_wd = gr.Checkbox(
+                        value=False,
+                        label="DoRA Weight Decompose",
+                        info="Enable the DoRA method for these algorithms",
+                        visible=False,
+                    )
+                    use_cp = gr.Checkbox(
+                        value=False,
+                        label="Use CP decomposition",
+                        info="A two-step approach utilizing tensor decomposition and fine-tuning to accelerate convolution layers in large neural networks, resulting in significant CPU speedups with minor accuracy drops.",
+                        visible=False,
+                    )
+                    use_tucker = gr.Checkbox(
+                        value=False,
+                        label="Use Tucker decomposition",
+                        info="Efficiently decompose tensor shapes, resulting in a sequence of convolution layers with varying dimensions and Hadamard product implementation through multiplication of two distinct tensors.",
+                        visible=False,
+                    )
+                    use_scalar = gr.Checkbox(
+                        value=False,
+                        label="Use Scalar",
+                        info="Train an additional scalar in front of the weight difference, use a different weight initialization strategy.",
+                        visible=False,
+                    )
                 with gr.Row():
-                    lora_network_weights = gr.Textbox(
-                        label="LoRA network weights",
-                        placeholder="(Optional)",
-                        info="Path to an existing LoRA network weights to resume training from",
+                    rank_dropout_scale = gr.Checkbox(
+                        value=False,
+                        label="Rank Dropout Scale",
+                        info="Adjusts the scale of the rank dropout to maintain the average dropout rate, ensuring more consistent regularization across different layers.",
+                        visible=False,
                     )
-                    lora_network_weights_file = gr.Button(
-                        document_symbol,
-                        elem_id="open_folder_small",
-                        elem_classes=["tool"],
-                        visible=(not headless),
+                    constrain = gr.Number(
+                        value="0.0",
+                        label="Constrain OFT",
+                        info="Limits the norm of the oft_blocks, ensuring that their magnitude does not exceed a specified threshold, thus controlling the extent of the transformation applied.",
+                        visible=False,
                     )
-                    lora_network_weights_file.click(
-                        get_any_file_path,
-                        inputs=[lora_network_weights],
-                        outputs=lora_network_weights,
-                        show_progress=False,
+                    rescaled = gr.Checkbox(
+                        value=False,
+                        label="Rescaled OFT",
+                        info="applies an additional scaling factor to the oft_blocks, allowing for further adjustment of their impact on the model's transformations.",
+                        visible=False,
                    )
-                    dim_from_weights = gr.Checkbox(
-                        label="DIM from weights",
+                    train_norm = gr.Checkbox(
                         value=False,
-                        info="Automatically determine the dim(rank) from the weight file.",
+                        label="Train Norm",
+                        info="Selects trainable layers in a network, but trains normalization layers identically across methods as they lack matrix decomposition.",
+                        visible=False,
                     )
-            basic_training = BasicTraining(
-                learning_rate_value="0.0001",
-                lr_scheduler_value="cosine",
-                lr_warmup_value="10",
-                sdxl_checkbox=source_model.sdxl_checkbox,
-            )
-
-            with gr.Row():
-                text_encoder_lr = gr.Number(
-                    label="Text Encoder learning rate",
-                    value="0.0001",
-                    info="(Optional)",
-                    minimum=0,
-                    maximum=1,
-                )
-
-                unet_lr = gr.Number(
-                    label="Unet learning rate",
-                    value="0.0001",
-                    info="(Optional)",
-                    minimum=0,
-                    maximum=1,
-                )
-
-            # Add SDXL Parameters
-            sdxl_params = SDXLParameters(source_model.sdxl_checkbox)
-
-            # LyCORIS Specific parameters
-            with gr.Accordion("LyCORIS", visible=False) as lycoris_accordion:
-                with gr.Row():
-                    factor = gr.Slider(
-                        label="LoKr factor",
-                        value=-1,
-                        minimum=-1,
-                        maximum=64,
+                    decompose_both = gr.Checkbox(
+                        value=False,
+                        label="LoKr decompose both",
+                        info="Controls whether both input and output dimensions of the layer's weights are decomposed into smaller matrices for reparameterization.",
+                        visible=False,
+                    )
+                    train_on_input = gr.Checkbox(
+                        value=True,
+                        label="iA3 train on input",
+                        info="Set if we change the information going into the system (True) or the information coming out of it (False).",
+                        visible=False,
+                    )
+            with gr.Row() as network_row:
+                network_dim = gr.Slider(
+                    minimum=1,
+                    maximum=512,
+                    label="Network Rank (Dimension)",
+                    value=8,
                         step=1,
-                        visible=False,
+                    interactive=True,
                     )
-                    bypass_mode = gr.Checkbox(
-                        value=False,
-                        label="Bypass mode",
-                        info="Designed for bnb 8bit/4bit linear layer. (QLyCORIS)",
-                        visible=False,
-                    )
-                    dora_wd = gr.Checkbox(
-                        value=False,
-                        label="DoRA Weight Decompose",
-                        info="Enable the DoRA method for these algorithms",
-                        visible=False,
+                network_alpha = gr.Slider(
+                    minimum=0.00001,
+                    maximum=1024,
+                    label="Network Alpha",
+                    value=1,
+                    step=0.00001,
+                    interactive=True,
+                    info="alpha for LoRA weight scaling",
                     )
-                    use_cp = gr.Checkbox(
-                        value=False,
-                        label="Use CP decomposition",
-                        info="A two-step approach utilizing tensor decomposition and fine-tuning to accelerate convolution layers in large neural networks, resulting in significant CPU speedups with minor accuracy drops.",
-                        visible=False,
-                    )
-                    use_tucker = gr.Checkbox(
-                        value=False,
-                        label="Use Tucker decomposition",
-                        info="Efficiently decompose tensor shapes, resulting in a sequence of convolution layers with varying dimensions and Hadamard product implementation through multiplication of two distinct tensors.",
-                        visible=False,
+            with gr.Row(visible=False) as convolution_row:
+                # locon= gr.Checkbox(label='Train a LoCon instead of a general LoRA (does not support v2 base models) (may not be able to some utilities now)', value=False)
+                conv_dim = gr.Slider(
+                    minimum=0,
+                    maximum=512,
+                    value=1,
+                    step=1,
+                    label="Convolution Rank (Dimension)",
                     )
-                    use_scalar = gr.Checkbox(
-                        value=False,
-                        label="Use Scalar",
-                        info="Train an additional scalar in front of the weight difference, use a different weight initialization strategy.",
-                        visible=False,
+                conv_alpha = gr.Slider(
+                    minimum=0,
+                    maximum=512,
+                    value=1,
+                    step=1,
+                    label="Convolution Alpha",
                     )
                 with gr.Row():
-                    rank_dropout_scale = gr.Checkbox(
-                        value=False,
-                        label="Rank Dropout Scale",
-                        info="Adjusts the scale of the rank dropout to maintain the average dropout rate, ensuring more consistent regularization across different layers.",
-                        visible=False,
+                scale_weight_norms = gr.Slider(
+                    label="Scale weight norms",
+                    value=0,
+                    minimum=0,
+                    maximum=10,
+                    step=0.01,
+                    info="Max Norm Regularization is a technique to stabilize network training by limiting the norm of network weights. It may be effective in suppressing overfitting of LoRA and improving stability when used with other LoRAs. See PR #545 on kohya_ss/sd_scripts repo for details. Recommended setting: 1. Higher is weaker, lower is stronger.",
+                    interactive=True,
                     )
-                    constrain = gr.Number(
-                        value="0.0",
-                        label="Constrain OFT",
-                        info="Limits the norm of the oft_blocks, ensuring that their magnitude does not exceed a specified threshold, thus controlling the extent of the transformation applied.",
-                        visible=False,
+                network_dropout = gr.Slider(
+                    label="Network dropout",
+                    value=0,
+                    minimum=0,
+                    maximum=1,
+                    step=0.01,
+                    info="Is a normal probability dropout at the neuron level. In the case of LoRA, it is applied to the output of down. Recommended range 0.1 to 0.5",
                     )
-                    rescaled = gr.Checkbox(
-                        value=False,
-                        label="Rescaled OFT",
-                        info="applies an additional scaling factor to the oft_blocks, allowing for further adjustment of their impact on the model's transformations.",
-                        visible=False,
+                rank_dropout = gr.Slider(
+                    label="Rank dropout",
+                    value=0,
+                    minimum=0,
+                    maximum=1,
+                    step=0.01,
+                    info="can specify `rank_dropout` to dropout each rank with specified probability. Recommended range 0.1 to 0.3",
                     )
-                    train_norm = gr.Checkbox(
-                        value=False,
-                        label="Train Norm",
-                        info="Selects trainable layers in a network, but trains normalization layers identically across methods as they lack matrix decomposition.",
-                        visible=False,
+                module_dropout = gr.Slider(
+                    label="Module dropout",
+                    value=0.0,
+                    minimum=0.0,
+                    maximum=1.0,
+                    step=0.01,
+                    info="can specify `module_dropout` to dropout each rank with specified probability. Recommended range 0.1 to 0.3",
                     )
-                    decompose_both = gr.Checkbox(
-                        value=False,
-                        label="LoKr decompose both",
-                        info="Controls whether both input and output dimensions of the layer's weights are decomposed into smaller matrices for reparameterization.",
-                        visible=False,
-                    )
-                    train_on_input = gr.Checkbox(
-                        value=True,
-                        label="iA3 train on input",
-                        info="Set if we change the information going into the system (True) or the information coming out of it (False).",
-                        visible=False,
+                with gr.Row(visible=False):
+                unit = gr.Slider(
+                    minimum=1,
+                    maximum=64,
+                    label="DyLoRA Unit / Block size",
+                    value=1,
+                    step=1,
+                    interactive=True,
                     )
-            with gr.Row() as network_row:
-                network_dim = gr.Slider(
-                    minimum=1,
-                    maximum=512,
-                    label="Network Rank (Dimension)",
-                    value=8,
-                    step=1,
-                    interactive=True,
-                )
-                network_alpha = gr.Slider(
-                    minimum=0.00001,
-                    maximum=1024,
-                    label="Network Alpha",
-                    value=1,
-                    step=0.00001,
-                    interactive=True,
-                    info="alpha for LoRA weight scaling",
-                )
-            with gr.Row(visible=False) as convolution_row:
-                # locon= gr.Checkbox(label='Train a LoCon instead of a general LoRA (does not support v2 base models) (may not be able to some utilities now)', value=False)
-                conv_dim = gr.Slider(
-                    minimum=0,
-                    maximum=512,
-                    value=1,
-                    step=1,
-                    label="Convolution Rank (Dimension)",
-                )
-                conv_alpha = gr.Slider(
-                    minimum=0,
-                    maximum=512,
-                    value=1,
-                    step=1,
-                    label="Convolution Alpha",
-                )
-            with gr.Row():
-                scale_weight_norms = gr.Slider(
-                    label="Scale weight norms",
-                    value=0,
-                    minimum=0,
-                    maximum=10,
-                    step=0.01,
-                    info="Max Norm Regularization is a technique to stabilize network training by limiting the norm of network weights. It may be effective in suppressing overfitting of LoRA and improving stability when used with other LoRAs. See PR #545 on kohya_ss/sd_scripts repo for details. Recommended setting: 1. Higher is weaker, lower is stronger.",
-                    interactive=True,
-                )
-                network_dropout = gr.Slider(
-                    label="Network dropout",
-                    value=0,
-                    minimum=0,
-                    maximum=1,
-                    step=0.01,
-                    info="Is a normal probability dropout at the neuron level. In the case of LoRA, it is applied to the output of down. Recommended range 0.1 to 0.5",
-                )
-                rank_dropout = gr.Slider(
-                    label="Rank dropout",
-                    value=0,
-                    minimum=0,
-                    maximum=1,
-                    step=0.01,
-                    info="can specify `rank_dropout` to dropout each rank with specified probability. Recommended range 0.1 to 0.3",
-                )
-                module_dropout = gr.Slider(
-                    label="Module dropout",
-                    value=0.0,
-                    minimum=0.0,
-                    maximum=1.0,
-                    step=0.01,
-                    info="can specify `module_dropout` to dropout each rank with specified probability. Recommended range 0.1 to 0.3",
-                )
-            with gr.Row(visible=False):
-                unit = gr.Slider(
-                    minimum=1,
-                    maximum=64,
-                    label="DyLoRA Unit / Block size",
-                    value=1,
-                    step=1,
-                    interactive=True,
-                )
-
-    # Show or hide LoCon conv settings depending on LoRA type selection
-    def update_LoRA_settings(
-        LoRA_type,
-        conv_dim,
-        network_dim,
-    ):
-        log.info("LoRA type changed...")
-
-        lora_settings_config = {
-            "network_row": {
-                "gr_type": gr.Row,
-                "update_params": {
-                    "visible": LoRA_type
-                    in {
-                        "Kohya DyLoRA",
-                        "Kohya LoCon",
-                        "LoRA-FA",
-                        "LyCORIS/BOFT",
-                        "LyCORIS/Diag-OFT",
-                        "LyCORIS/DyLoRA",
-                        "LyCORIS/GLoRA",
-                        "LyCORIS/LoCon",
-                        "LyCORIS/LoHa",
-                        "LyCORIS/LoKr",
-                        "Standard",
+        # Show or hide LoCon conv settings depending on LoRA type selection
+        def update_LoRA_settings(
+            LoRA_type,
+            conv_dim,
+            network_dim,
+        ):
+            log.info("LoRA type changed...")
+
+            lora_settings_config = {
+                "network_row": {
+                    "gr_type": gr.Row,
+                    "update_params": {
+                        "visible": LoRA_type
+                        in {
+                            "Kohya DyLoRA",
+                            "Kohya LoCon",
+                            "LoRA-FA",
+                            "LyCORIS/BOFT",
+                            "LyCORIS/Diag-OFT",
+                            "LyCORIS/DyLoRA",
+                            "LyCORIS/GLoRA",
+                            "LyCORIS/LoCon",
+                            "LyCORIS/LoHa",
+                            "LyCORIS/LoKr",
+                            "Standard",
+                        },
+                    },
+                },
+                "convolution_row": {
+                    "gr_type": gr.Row,
+                    "update_params": {
+                        "visible": LoRA_type
+                        in {
+                            "LoCon",
+                            "Kohya DyLoRA",
+                            "Kohya LoCon",
+                            "LoRA-FA",
+                            "LyCORIS/BOFT",
+                            "LyCORIS/Diag-OFT",
+                            "LyCORIS/DyLoRA",
+                            "LyCORIS/LoHa",
+                            "LyCORIS/LoKr",
+                            "LyCORIS/LoCon",
+                            "LyCORIS/GLoRA",
+                        },
                     },
                 },
-            },
-            "convolution_row": {
-                "gr_type": gr.Row,
-                "update_params": {
-                    "visible": LoRA_type
-                    in {
-                        "LoCon",
-                        "Kohya DyLoRA",
-                        "Kohya LoCon",
-                        "LoRA-FA",
-                        "LyCORIS/BOFT",
-                        "LyCORIS/Diag-OFT",
-                        "LyCORIS/DyLoRA",
-                        "LyCORIS/LoHa",
-                        "LyCORIS/LoKr",
-                        "LyCORIS/LoCon",
-                        "LyCORIS/GLoRA",
+                "kohya_advanced_lora": {
+                    "gr_type": gr.Row,
+                    "update_params": {
+                        "visible": LoRA_type
+                        in {
+                            "Standard",
+                            "Kohya DyLoRA",
+                            "Kohya LoCon",
+                            "LoRA-FA",
+                        },
                     },
                 },
-            },
-            "kohya_advanced_lora": {
-                "gr_type": gr.Row,
-                "update_params": {
-                    "visible": LoRA_type
-                    in {
-                        "Standard",
-                        "Kohya DyLoRA",
-                        "Kohya LoCon",
-                        "LoRA-FA",
+                "lora_network_weights": {
+                    "gr_type": gr.Textbox,
+                    "update_params": {
+                        "visible": LoRA_type
+                        in {
+                            "Standard",
+                            "LoCon",
+                            "Kohya DyLoRA",
+                            "Kohya LoCon",
+                            "LoRA-FA",
+                            "LyCORIS/BOFT",
+                            "LyCORIS/Diag-OFT",
+                            "LyCORIS/DyLoRA",
+                            "LyCORIS/GLoRA",
+                            "LyCORIS/LoHa",
+                            "LyCORIS/LoCon",
+                            "LyCORIS/LoKr",
+                        },
+                    },
+                },
+                "lora_network_weights_file": {
+                    "gr_type": gr.Button,
+                    "update_params": {
+                        "visible": LoRA_type
+                        in {
+                            "Standard",
+                            "LoCon",
+                            "Kohya DyLoRA",
+                            "Kohya LoCon",
+                            "LoRA-FA",
+                            "LyCORIS/BOFT",
+                            "LyCORIS/Diag-OFT",
+                            "LyCORIS/DyLoRA",
+                            "LyCORIS/GLoRA",
+                            "LyCORIS/LoHa",
+                            "LyCORIS/LoCon",
+                            "LyCORIS/LoKr",
+                        },
                     },
                 },
-            },
-            "lora_network_weights": {
-                "gr_type": gr.Textbox,
-                "update_params": {
-                    "visible": LoRA_type
-                    in {
-                        "Standard",
-                        "LoCon",
-                        "Kohya DyLoRA",
-                        "Kohya LoCon",
-                        "LoRA-FA",
-                        "LyCORIS/BOFT",
-                        "LyCORIS/Diag-OFT",
-                        "LyCORIS/DyLoRA",
-                        "LyCORIS/GLoRA",
-                        "LyCORIS/LoHa",
-                        "LyCORIS/LoCon",
-                        "LyCORIS/LoKr",
+                "dim_from_weights": {
+                    "gr_type": gr.Checkbox,
+                    "update_params": {
+                        "visible": LoRA_type
+                        in {
+                            "Standard",
+                            "LoCon",
+                            "Kohya DyLoRA",
+                            "Kohya LoCon",
+                            "LoRA-FA",
+                            "LyCORIS/BOFT",
+                            "LyCORIS/Diag-OFT",
+                            "LyCORIS/DyLoRA",
+                            "LyCORIS/GLoRA",
+                            "LyCORIS/LoHa",
+                            "LyCORIS/LoCon",
+                            "LyCORIS/LoKr",
+                        }
                     },
                 },
-            },
-            "lora_network_weights_file": {
-                "gr_type": gr.Button,
-                "update_params": {
-                    "visible": LoRA_type
-                    in {
-                        "Standard",
-                        "LoCon",
-                        "Kohya DyLoRA",
-                        "Kohya LoCon",
-                        "LoRA-FA",
-                        "LyCORIS/BOFT",
-                        "LyCORIS/Diag-OFT",
-                        "LyCORIS/DyLoRA",
-                        "LyCORIS/GLoRA",
-                        "LyCORIS/LoHa",
-                        "LyCORIS/LoCon",
-                        "LyCORIS/LoKr",
+                "factor": {
+                    "gr_type": gr.Slider,
+                    "update_params": {
+                        "visible": LoRA_type
+                        in {
+                            "LyCORIS/LoKr",
+                        },
                     },
                 },
-            },
-            "dim_from_weights": {
-                "gr_type": gr.Checkbox,
-                "update_params": {
-                    "visible": LoRA_type
-                    in {
-                        "Standard",
-                        "LoCon",
-                        "Kohya DyLoRA",
-                        "Kohya LoCon",
-                        "LoRA-FA",
-                        "LyCORIS/BOFT",
-                        "LyCORIS/Diag-OFT",
-                        "LyCORIS/DyLoRA",
-                        "LyCORIS/GLoRA",
-                        "LyCORIS/LoHa",
-                        "LyCORIS/LoCon",
-                        "LyCORIS/LoKr",
-                    }
+                "conv_dim": {
+                    "gr_type": gr.Slider,
+                    "update_params": {
+                        "maximum": (
+                            100000
+                            if LoRA_type
+                            in {
+                                "LyCORIS/LoHa",
+                                "LyCORIS/LoKr",
+                                "LyCORIS/BOFT",
+                                "LyCORIS/Diag-OFT",
+                            }
+                            else 512
+                        ),
+                        "value": conv_dim,  # if conv_dim > 512 else conv_dim,
+                    },
                 },
-            },
-            "factor": {
-                "gr_type": gr.Slider,
-                "update_params": {
-                    "visible": LoRA_type
-                    in {
-                        "LyCORIS/LoKr",
+                "network_dim": {
+                    "gr_type": gr.Slider,
+                    "update_params": {
+                        "maximum": (
+                            100000
+                            if LoRA_type
+                            in {
+                                "LyCORIS/LoHa",
+                                "LyCORIS/LoKr",
+                                "LyCORIS/BOFT",
+                                "LyCORIS/Diag-OFT",
+                            }
+                            else 512
+                        ),
+                        "value": network_dim,  # if network_dim > 512 else network_dim,
                     },
                 },
-            },
-            "conv_dim": {
-                "gr_type": gr.Slider,
-                "update_params": {
-                    "maximum": (
-                        100000
-                        if LoRA_type
+                "bypass_mode": {
+                    "gr_type": gr.Checkbox,
+                    "update_params": {
+                        "visible": LoRA_type
                         in {
+                            "LyCORIS/LoCon",
                             "LyCORIS/LoHa",
                             "LyCORIS/LoKr",
-                            "LyCORIS/BOFT",
-                            "LyCORIS/Diag-OFT",
-                        }
-                        else 512
-                    ),
-                    "value": conv_dim,  # if conv_dim > 512 else conv_dim,
+                        },
+                    },
                 },
-            },
-            "network_dim": {
-                "gr_type": gr.Slider,
-                "update_params": {
-                    "maximum": (
-                        100000
-                        if LoRA_type
+                "dora_wd": {
+                    "gr_type": gr.Checkbox,
+                    "update_params": {
+                        "visible": LoRA_type
                         in {
+                            "LyCORIS/LoCon",
                             "LyCORIS/LoHa",
                             "LyCORIS/LoKr",
-                            "LyCORIS/BOFT",
-                            "LyCORIS/Diag-OFT",
-                        }
-                        else 512
-                    ),
-                    "value": network_dim,  # if network_dim > 512 else network_dim,
-                },
-            },
-            "bypass_mode": {
-                "gr_type": gr.Checkbox,
-                "update_params": {
-                    "visible": LoRA_type
-                    in {
-                        "LyCORIS/LoCon",
-                        "LyCORIS/LoHa",
-                        "LyCORIS/LoKr",
+                        },
                     },
                 },
-            },
-            "dora_wd": {
-                "gr_type": gr.Checkbox,
-                "update_params": {
-                    "visible": LoRA_type
-                    in {
-                        "LyCORIS/LoCon",
-                        "LyCORIS/LoHa",
-                        "LyCORIS/LoKr",
+                "use_cp": {
+                    "gr_type": gr.Checkbox,
+                    "update_params": {
+                        "visible": LoRA_type
+                        in {
+                            "LyCORIS/LoKr",
+                        },
                     },
                 },
-            },
-            "use_cp": {
-                "gr_type": gr.Checkbox,
-                "update_params": {
-                    "visible": LoRA_type
-                    in {
-                        "LyCORIS/LoKr",
+                "use_tucker": {
+                    "gr_type": gr.Checkbox,
+                    "update_params": {
+                        "visible": LoRA_type
+                        in {
+                            "LyCORIS/BOFT",
+                            "LyCORIS/Diag-OFT",
+                            "LyCORIS/DyLoRA",
+                            "LyCORIS/LoCon",
+                            "LyCORIS/LoHa",
+                            "LyCORIS/Native Fine-Tuning",
+                        },
                     },
                 },
-            },
-            "use_tucker": {
-                "gr_type": gr.Checkbox,
-                "update_params": {
-                    "visible": LoRA_type
-                    in {
-                        "LyCORIS/BOFT",
-                        "LyCORIS/Diag-OFT",
-                        "LyCORIS/DyLoRA",
-                        "LyCORIS/LoCon",
-                        "LyCORIS/LoHa",
-                        "LyCORIS/Native Fine-Tuning",
+                "use_scalar": {
+                    "gr_type": gr.Checkbox,
+                    "update_params": {
+                        "visible": LoRA_type
+                        in {
+                            "LyCORIS/BOFT",
+                            "LyCORIS/Diag-OFT",
+                            "LyCORIS/LoCon",
+                            "LyCORIS/LoHa",
+                            "LyCORIS/LoKr",
+                            "LyCORIS/Native Fine-Tuning",
+                        },
                     },
                 },
-            },
-            "use_scalar": {
-                "gr_type": gr.Checkbox,
-                "update_params": {
-                    "visible": LoRA_type
-                    in {
-                        "LyCORIS/BOFT",
-                        "LyCORIS/Diag-OFT",
-                        "LyCORIS/LoCon",
-                        "LyCORIS/LoHa",
-                        "LyCORIS/LoKr",
-                        "LyCORIS/Native Fine-Tuning",
+                "rank_dropout_scale": {
+                    "gr_type": gr.Checkbox,
+                    "update_params": {
+                        "visible": LoRA_type
+                        in {
+                            "LyCORIS/BOFT",
+                            "LyCORIS/Diag-OFT",
+                            "LyCORIS/GLoRA",
+                            "LyCORIS/LoCon",
+                            "LyCORIS/LoHa",
+                            "LyCORIS/LoKr",
+                            "LyCORIS/Native Fine-Tuning",
+                        },
                     },
                 },
-            },
-            "rank_dropout_scale": {
-                "gr_type": gr.Checkbox,
-                "update_params": {
-                    "visible": LoRA_type
-                    in {
-                        "LyCORIS/BOFT",
-                        "LyCORIS/Diag-OFT",
-                        "LyCORIS/GLoRA",
-                        "LyCORIS/LoCon",
-                        "LyCORIS/LoHa",
-                        "LyCORIS/LoKr",
-                        "LyCORIS/Native Fine-Tuning",
+                "constrain": {
+                    "gr_type": gr.Number,
+                    "update_params": {
+                        "visible": LoRA_type
+                        in {
+                            "LyCORIS/BOFT",
+                            "LyCORIS/Diag-OFT",
+                        },
                     },
                 },
-            },
-            "constrain": {
-                "gr_type": gr.Number,
-                "update_params": {
-                    "visible": LoRA_type
-                    in {
-                        "LyCORIS/BOFT",
-                        "LyCORIS/Diag-OFT",
+                "rescaled": {
+                    "gr_type": gr.Checkbox,
+                    "update_params": {
+                        "visible": LoRA_type
+                        in {
+                            "LyCORIS/BOFT",
+                            "LyCORIS/Diag-OFT",
+                        },
                     },
                 },
-            },
-            "rescaled": {
-                "gr_type": gr.Checkbox,
-                "update_params": {
-                    "visible": LoRA_type
-                    in {
-                        "LyCORIS/BOFT",
-                        "LyCORIS/Diag-OFT",
+                "train_norm": {
+                    "gr_type": gr.Checkbox,
+                    "update_params": {
+                        "visible": LoRA_type
+                        in {
+                            "LyCORIS/DyLoRA",
+                            "LyCORIS/BOFT",
+                            "LyCORIS/Diag-OFT",
+                            "LyCORIS/GLoRA",
+                            "LyCORIS/LoCon",
+                            "LyCORIS/LoHa",
+                            "LyCORIS/LoKr",
+                            "LyCORIS/Native Fine-Tuning",
+                        },
                     },
                 },
-            },
-            "train_norm": {
-                "gr_type": gr.Checkbox,
-                "update_params": {
-                    "visible": LoRA_type
-                    in {
-                        "LyCORIS/DyLoRA",
-                        "LyCORIS/BOFT",
-                        "LyCORIS/Diag-OFT",
-                        "LyCORIS/GLoRA",
-                        "LyCORIS/LoCon",
-                        "LyCORIS/LoHa",
-                        "LyCORIS/LoKr",
-                        "LyCORIS/Native Fine-Tuning",
+                "decompose_both": {
+                    "gr_type": gr.Checkbox,
+                    "update_params": {
+                        "visible": LoRA_type in {"LyCORIS/LoKr"},
                     },
                 },
-            },
-            "decompose_both": {
-                "gr_type": gr.Checkbox,
-                "update_params": {
-                    "visible": LoRA_type in {"LyCORIS/LoKr"},
-                },
-            },
-            "train_on_input": {
-                "gr_type": gr.Checkbox,
-                "update_params": {
-                    "visible": LoRA_type in {"LyCORIS/iA3"},
+                "train_on_input": {
+                    "gr_type": gr.Checkbox,
+                    "update_params": {
+                        "visible": LoRA_type in {"LyCORIS/iA3"},
+                    },
                 },
-            },
-            "scale_weight_norms": {
-                "gr_type": gr.Slider,
-                "update_params": {
-                    "visible": LoRA_type
-                    in {
-                        "LoCon",
-                        "Kohya DyLoRA",
-                        "Kohya LoCon",
-                        "LoRA-FA",
-                        "LyCORIS/DyLoRA",
-                        "LyCORIS/GLoRA",
-                        "LyCORIS/LoHa",
-                        "LyCORIS/LoCon",
-                        "LyCORIS/LoKr",
-                        "Standard",
+                "scale_weight_norms": {
+                    "gr_type": gr.Slider,
+                    "update_params": {
+                        "visible": LoRA_type
+                        in {
+                            "LoCon",
+                            "Kohya DyLoRA",
+                            "Kohya LoCon",
+                            "LoRA-FA",
+                            "LyCORIS/DyLoRA",
+                            "LyCORIS/GLoRA",
+                            "LyCORIS/LoHa",
+                            "LyCORIS/LoCon",
+                            "LyCORIS/LoKr",
+                            "Standard",
+                        },
                     },
                 },
-            },
-            "network_dropout": {
-                "gr_type": gr.Slider,
-                "update_params": {
-                    "visible": LoRA_type
-                    in {
-                        "LoCon",
-                        "Kohya DyLoRA",
-                        "Kohya LoCon",
-                        "LoRA-FA",
-                        "LyCORIS/BOFT",
-                        "LyCORIS/Diag-OFT",
-                        "LyCORIS/DyLoRA",
-                        "LyCORIS/GLoRA",
-                        "LyCORIS/LoCon",
-                        "LyCORIS/LoHa",
-                        "LyCORIS/LoKr",
-                        "LyCORIS/Native Fine-Tuning",
-                        "Standard",
+                "network_dropout": {
+                    "gr_type": gr.Slider,
+                    "update_params": {
+                        "visible": LoRA_type
+                        in {
+                            "LoCon",
+                            "Kohya DyLoRA",
+                            "Kohya LoCon",
+                            "LoRA-FA",
+                            "LyCORIS/BOFT",
+                            "LyCORIS/Diag-OFT",
+                            "LyCORIS/DyLoRA",
+                            "LyCORIS/GLoRA",
+                            "LyCORIS/LoCon",
+                            "LyCORIS/LoHa",
+                            "LyCORIS/LoKr",
+                            "LyCORIS/Native Fine-Tuning",
+                            "Standard",
+                        },
                     },
                 },
-            },
-            "rank_dropout": {
-                "gr_type": gr.Slider,
-                "update_params": {
-                    "visible": LoRA_type
-                    in {
-                        "LoCon",
-                        "Kohya DyLoRA",
-                        "LyCORIS/BOFT",
-                        "LyCORIS/Diag-OFT",
-                        "LyCORIS/GLoRA",
-                        "LyCORIS/LoCon",
-                        "LyCORIS/LoHa",
-                        "LyCORIS/LoKR",
-                        "Kohya LoCon",
-                        "LoRA-FA",
-                        "LyCORIS/Native Fine-Tuning",
-                        "Standard",
+                "rank_dropout": {
+                    "gr_type": gr.Slider,
+                    "update_params": {
+                        "visible": LoRA_type
+                        in {
+                            "LoCon",
+                            "Kohya DyLoRA",
+                            "LyCORIS/BOFT",
+                            "LyCORIS/Diag-OFT",
+                            "LyCORIS/GLoRA",
+                            "LyCORIS/LoCon",
+                            "LyCORIS/LoHa",
+                            "LyCORIS/LoKR",
+                            "Kohya LoCon",
+                            "LoRA-FA",
+                            "LyCORIS/Native Fine-Tuning",
+                            "Standard",
+                        },
                     },
                 },
-            },
-            "module_dropout": {
-                "gr_type": gr.Slider,
-                "update_params": {
-                    "visible": LoRA_type
-                    in {
-                        "LoCon",
-                        "LyCORIS/BOFT",
-                        "LyCORIS/Diag-OFT",
-                        "Kohya DyLoRA",
-                        "LyCORIS/GLoRA",
-                        "LyCORIS/LoCon",
-                        "LyCORIS/LoHa",
-                        "LyCORIS/LoKR",
-                        "Kohya LoCon",
-                        "LyCORIS/Native Fine-Tuning",
-                        "LoRA-FA",
-                        "Standard",
+                "module_dropout": {
+                    "gr_type": gr.Slider,
+                    "update_params": {
+                        "visible": LoRA_type
+                        in {
+                            "LoCon",
+                            "LyCORIS/BOFT",
+                            "LyCORIS/Diag-OFT",
+                            "Kohya DyLoRA",
+                            "LyCORIS/GLoRA",
+                            "LyCORIS/LoCon",
+                            "LyCORIS/LoHa",
+                            "LyCORIS/LoKR",
+                            "Kohya LoCon",
+                            "LyCORIS/Native Fine-Tuning",
+                            "LoRA-FA",
+                            "Standard",
+                        },
                     },
                 },
-            },
-            "LyCORIS_preset": {
-                "gr_type": gr.Dropdown,
-                "update_params": {
-                    "visible": LoRA_type
-                    in {
-                        "LyCORIS/DyLoRA",
-                        "LyCORIS/iA3",
-                        "LyCORIS/BOFT",
-                        "LyCORIS/Diag-OFT",
-                        "LyCORIS/GLoRA",
-                        "LyCORIS/LoCon",
-                        "LyCORIS/LoHa",
-                        "LyCORIS/LoKr",
-                        "LyCORIS/Native Fine-Tuning",
+                "LyCORIS_preset": {
+                    "gr_type": gr.Dropdown,
+                    "update_params": {
+                        "visible": LoRA_type
+                        in {
+                            "LyCORIS/DyLoRA",
+                            "LyCORIS/iA3",
+                            "LyCORIS/BOFT",
+                            "LyCORIS/Diag-OFT",
+                            "LyCORIS/GLoRA",
+                            "LyCORIS/LoCon",
+                            "LyCORIS/LoHa",
+                            "LyCORIS/LoKr",
+                            "LyCORIS/Native Fine-Tuning",
+                        },
                     },
                 },
-            },
-            "unit": {
-                "gr_type": gr.Slider,
-                "update_params": {
-                    "visible": LoRA_type
-                    in {
-                        "Kohya DyLoRA",
-                        "LyCORIS/DyLoRA",
+                "unit": {
+                    "gr_type": gr.Slider,
+                    "update_params": {
+                        "visible": LoRA_type
+                        in {
+                            "Kohya DyLoRA",
+                            "LyCORIS/DyLoRA",
+                        },
                     },
                 },
-            },
-            "lycoris_accordion": {
-                "gr_type": gr.Accordion,
-                "update_params": {
-                    "visible": LoRA_type
-                    in {
-                        "LyCORIS/DyLoRA",
-                        "LyCORIS/iA3",
-                        "LyCORIS/BOFT",
-                        "LyCORIS/Diag-OFT",
-                        "LyCORIS/GLoRA",
-                        "LyCORIS/LoCon",
-                        "LyCORIS/LoHa",
-                        "LyCORIS/LoKr",
-                        "LyCORIS/Native Fine-Tuning",
+                "lycoris_accordion": {
+                    "gr_type": gr.Accordion,
+                    "update_params": {
+                        "visible": LoRA_type
+                        in {
+                            "LyCORIS/DyLoRA",
+                            "LyCORIS/iA3",
+                            "LyCORIS/BOFT",
+                            "LyCORIS/Diag-OFT",
+                            "LyCORIS/GLoRA",
+                            "LyCORIS/LoCon",
+                            "LyCORIS/LoHa",
+                            "LyCORIS/LoKr",
+                            "LyCORIS/Native Fine-Tuning",
+                        },
                     },
                 },
-            },
-        }
+            }
 
-        results = []
-        for attr, settings in lora_settings_config.items():
-            update_params = settings["update_params"]
+            results = []
+            for attr, settings in lora_settings_config.items():
+                update_params = settings["update_params"]
 
-            results.append(settings["gr_type"](**update_params))
+                results.append(settings["gr_type"](**update_params))
 
-        return tuple(results)
+            return tuple(results)
 
         with gr.Accordion("Advanced", open=False, elem_id="advanced_tab"):
             # with gr.Accordion('Advanced Configuration', open=False):
@@ -1818,7 +1853,7 @@ def update_LoRA_settings(
             )
 
         with gr.Accordion("Samples", open=False, elem_id="samples_tab"):
-            sample = SampleImages()
+            sample = SampleImages(config=config)
 
         LoRA_type.change(
            update_LoRA_settings,
@@ -1868,6 +1903,7 @@ def update_LoRA_settings(
             output_dir_input=folders.output_dir,
             logging_dir_input=folders.logging_dir,
             headless=headless,
+            config=config,
         )
 
         gradio_dataset_balancing_tab(headless=headless)
@@ -1918,10 +1954,10 @@ def update_LoRA_settings(
             basic_training.train_batch_size,
             basic_training.epoch,
             basic_training.save_every_n_epochs,
-            basic_training.mixed_precision,
+            accelerate_launch.mixed_precision,
             source_model.save_precision,
             basic_training.seed,
-            basic_training.num_cpu_threads_per_process,
+            accelerate_launch.num_cpu_threads_per_process,
             basic_training.cache_latents,
             basic_training.cache_latents_to_disk,
             basic_training.caption_extension,
@@ -1937,6 +1973,7 @@ def update_LoRA_settings(
             source_model.save_model_as,
             advanced_training.shuffle_caption,
             advanced_training.save_state,
+            advanced_training.save_state_on_train_end,
             advanced_training.resume,
             advanced_training.prior_loss_weight,
             text_encoder_lr,
@@ -1946,11 +1983,13 @@ def update_LoRA_settings(
             dim_from_weights,
             advanced_training.color_aug,
             advanced_training.flip_aug,
+            advanced_training.masked_loss,
             advanced_training.clip_skip,
-            advanced_training.num_processes,
-            advanced_training.num_machines,
-            advanced_training.multi_gpu,
-            advanced_training.gpu_ids,
+            accelerate_launch.num_processes,
+            accelerate_launch.num_machines,
+            accelerate_launch.multi_gpu,
+            accelerate_launch.gpu_ids,
+            accelerate_launch.main_process_port,
             advanced_training.gradient_accumulation_steps,
             advanced_training.mem_eff_attn,
             source_model.output_name,
@@ -1977,9 +2016,12 @@ def update_LoRA_settings(
             basic_training.max_grad_norm,
             advanced_training.noise_offset_type,
             advanced_training.noise_offset,
+            advanced_training.noise_offset_random_strength,
             advanced_training.adaptive_noise_scale,
             advanced_training.multires_noise_iterations,
             advanced_training.multires_noise_discount,
+            advanced_training.ip_noise_gamma,
+            advanced_training.ip_noise_gamma_random_strength,
             LoRA_type,
             factor,
             bypass_mode,
@@ -2033,6 +2075,7 @@ def update_LoRA_settings(
             advanced_training.vae,
             LyCORIS_preset,
             advanced_training.debiased_estimation_loss,
+            accelerate_launch.extra_accelerate_launch_args,
         ]
 
         configuration.button_open_config.click(
diff --git a/kohya_gui/manual_caption_gui.py b/kohya_gui/manual_caption_gui.py
index 9494b5dd2..2a80bfc06 100644
--- a/kohya_gui/manual_caption_gui.py
+++ b/kohya_gui/manual_caption_gui.py
@@ -11,7 +11,7 @@
 log = setup_logging()
 
 IMAGES_TO_SHOW = 5
-IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.webp', '.bmp')
+IMAGE_EXTENSIONS = (".png", ".jpg", ".jpeg", ".webp", ".bmp")
 auto_save = True
 
 
@@ -28,7 +28,7 @@ def _get_quick_tags(quick_tags_text):
     """
     Gets a list of tags from the quick tags text box
     """
-    quick_tags = [t.strip() for t in quick_tags_text.split(',') if t.strip()]
+    quick_tags = [t.strip() for t in quick_tags_text.split(",") if t.strip()]
     quick_tags_set = set(quick_tags)
     return quick_tags, quick_tags_set
 
@@ -38,34 +38,31 @@ def _get_tag_checkbox_updates(caption, quick_tags, quick_tags_set):
     Updates a list of caption checkboxes to show possible tags and tags
     already included in the caption
     """
-    caption_tags_have = [c.strip() for c in caption.split(',') if c.strip()]
-    caption_tags_unique = [
-        t for t in caption_tags_have if t not in quick_tags_set
-    ]
+    caption_tags_have = [c.strip() for c in caption.split(",") if c.strip()]
+    caption_tags_unique = [t for t in caption_tags_have if t not in quick_tags_set]
     caption_tags_all = quick_tags + caption_tags_unique
-    return gr.CheckboxGroup(
-        choices=caption_tags_all, value=caption_tags_have
-    )
+    return gr.CheckboxGroup(choices=caption_tags_all, value=caption_tags_have)
 
 
 def paginate_go(page, max_page):
     try:
         page = float(page)
     except:
-        msgbox(f'Invalid page num: {page}')
+        msgbox(f"Invalid page num: {page}")
         return
     return paginate(page, max_page, 0)
 
+
 def paginate(page, max_page, page_change):
     return int(max(min(page + page_change, max_page), 1))
 
 
 def save_caption(caption, caption_ext, image_file, images_dir):
     caption_path = _get_caption_path(image_file, images_dir, caption_ext)
-    with open(caption_path, 'w+', encoding='utf8') as f:
+    with open(caption_path, "w+", encoding="utf8") as f:
         f.write(caption)
 
-    log.info(f'Wrote captions to {caption_path}')
+    log.info(f"Wrote captions to {caption_path}")
 
 
 def update_quick_tags(quick_tags_text, *image_caption_texts):
@@ -101,7 +98,7 @@ def update_image_tags(
     output_tags = [t for t in quick_tags if t in selected_tags_set] + [
         t for t in selected_tags if t not in quick_tags_set
     ]
-    caption = ', '.join(output_tags)
+    caption = ", ".join(output_tags)
 
     if auto_save:
         save_caption(caption, caption_ext, image_file, images_dir)
@@ -122,50 +119,46 @@ def empty_return():
 
     # Check for images_dir
     if not images_dir:
-        msgbox('Image folder is missing...')
+        msgbox("Image folder is missing...")
         return empty_return()
 
     if not os.path.exists(images_dir):
-        msgbox('Image folder does not exist...')
+        msgbox("Image folder does not exist...")
         return empty_return()
 
     if not caption_ext:
-        msgbox('Please provide an extension for the caption files.')
+        msgbox("Please provide an extension for the caption files.")
         return empty_return()
 
     if quick_tags_text:
         if not boolbox(
-            f'Are you sure you wish to overwrite the current quick tags?',
-            choices=('Yes', 'No'),
+            f"Are you sure you wish to overwrite the current quick tags?",
+            choices=("Yes", "No"),
         ):
             return empty_return()
 
     images_list = os.listdir(images_dir)
-    image_files = [
-        f for f in images_list if f.lower().endswith(IMAGE_EXTENSIONS)
-    ]
+    image_files = [f for f in images_list if f.lower().endswith(IMAGE_EXTENSIONS)]
 
     # Use a set for lookup but store order with list
     tags = []
     tags_set = set()
     for image_file in image_files:
-        caption_file_path = _get_caption_path(
-            image_file, images_dir, caption_ext
-        )
+        caption_file_path = _get_caption_path(image_file, images_dir, caption_ext)
         if os.path.exists(caption_file_path):
-            with open(caption_file_path, 'r', encoding='utf8') as f:
+            with open(caption_file_path, "r", encoding="utf8") as f:
                 caption = f.read()
-                for tag in caption.split(','):
+                for tag in caption.split(","):
                     tag = tag.strip()
                     tag_key = tag.lower()
                     if not tag_key in tags_set:
                         # Ignore extra spaces
-                        total_words = len(re.findall(r'\s+', tag)) + 1
+                        total_words = len(re.findall(r"\s+", tag)) + 1
                         if total_words <= ignore_load_tags_word_count:
                             tags.append(tag)
                             tags_set.add(tag_key)
 
-    return ', '.join(tags)
+    return ", ".join(tags)
 
 
 def load_images(images_dir, caption_ext, loaded_images_dir, page, max_page):
@@ -180,15 +173,15 @@ def empty_return():
 
     # Check for images_dir
     if not images_dir:
-        msgbox('Image folder is missing...')
+        msgbox("Image folder is missing...")
         return empty_return()
 
     if not os.path.exists(images_dir):
-        msgbox('Image folder does not exist...')
+        msgbox("Image folder does not exist...")
         return empty_return()
 
     if not caption_ext:
-        msgbox('Please provide an extension for the caption files.')
+        msgbox("Please provide an extension for the caption files.")
         return empty_return()
 
     # Load Images
@@ -212,12 +205,10 @@ def update_images(
 
     # Load Images
     images_list = os.listdir(images_dir)
-    image_files = [
-        f for f in images_list if f.lower().endswith(IMAGE_EXTENSIONS)
-    ]
+    image_files = [f for f in images_list if f.lower().endswith(IMAGE_EXTENSIONS)]
 
     # Quick tags
-    quick_tags, quick_tags_set = _get_quick_tags(quick_tags_text or '')
+    quick_tags, quick_tags_set = _get_quick_tags(quick_tags_text or "")
 
     # Display Images
     rows = []
@@ -231,22 +222,18 @@ def update_images(
         show_row = image_index < len(image_files)
 
         image_path = None
-        caption = ''
+        caption = ""
         tag_checkboxes = None
         if show_row:
            image_file = image_files[image_index]
             image_path = os.path.join(images_dir, image_file)
-            caption_file_path = _get_caption_path(
-                image_file, images_dir, caption_ext
-            )
+            caption_file_path = _get_caption_path(image_file, images_dir, caption_ext)
             if os.path.exists(caption_file_path):
-                with open(caption_file_path, 'r', encoding='utf8') as f:
+                with open(caption_file_path, "r", encoding="utf8") as f:
                     caption = f.read()
 
-        tag_checkboxes = _get_tag_checkbox_updates(
-            caption, quick_tags, quick_tags_set
-        )
+        tag_checkboxes = _get_tag_checkbox_updates(caption, quick_tags, quick_tags_set)
         rows.append(gr.Row(visible=show_row))
         image_paths.append(image_path)
         captions.append(caption)
@@ -266,7 +253,11 @@ def update_images(
 def gradio_manual_caption_gui_tab(headless=False, default_images_dir=None):
     from .common_gui import create_refresh_button
 
-    default_images_dir = default_images_dir if default_images_dir is not None else os.path.join(scriptdir, "data")
+    default_images_dir = (
+        default_images_dir
+        if default_images_dir is not None
+        else os.path.join(scriptdir, "data")
+    )
     current_images_dir = default_images_dir
 
     # Function to list directories
@@ -276,39 +267,45 @@ def list_images_dirs(path):
             current_images_dir = path
         return list(list_dirs(path))
 
-    with gr.Tab('Manual Captioning'):
-        gr.Markdown(
-            'This utility allows quick captioning and tagging of images.'
-        )
+    with gr.Tab("Manual Captioning"):
+        gr.Markdown("This utility allows quick captioning and tagging of images.")
         page = gr.Number(-1, visible=False)
         max_page = gr.Number(1, visible=False)
         loaded_images_dir = gr.Text(visible=False)
         with gr.Group(), gr.Row():
             images_dir = gr.Dropdown(
-                label='Image folder to caption (containing the images to caption)',
+                label="Image folder to caption (containing the images to caption)",
                 choices=[""] + list_images_dirs(default_images_dir),
                 value="",
                 interactive=True,
                 allow_custom_value=True,
             )
-            create_refresh_button(images_dir, lambda: None, lambda: {"choices": list_images_dirs(current_images_dir)},"open_folder_small")
+            create_refresh_button(
+                images_dir,
+                lambda: None,
+                lambda: {"choices": list_images_dirs(current_images_dir)},
+                "open_folder_small",
+            )
             folder_button = gr.Button(
-                '📂', elem_id='open_folder_small', elem_classes=['tool'], visible=(not headless)
+                "📂",
+                elem_id="open_folder_small",
+                elem_classes=["tool"],
+                visible=(not headless),
            )
             folder_button.click(
                 get_folder_path,
                 outputs=images_dir,
                 show_progress=False,
             )
-            load_images_button = gr.Button('Load', elem_id='open_folder')
+            load_images_button = gr.Button("Load", elem_id="open_folder")
             caption_ext = gr.Textbox(
-                label='Caption file extension',
-                placeholder='Extension for caption file (e.g., .caption, .txt)',
-                value='.txt',
+                label="Caption file extension",
+                placeholder="Extension for caption file (e.g., .caption, .txt)",
+                value=".txt",
                 interactive=True,
             )
             auto_save = gr.Checkbox(
-                label='Autosave', info='Options', value=True, interactive=True
+                label="Autosave", info="Options", value=True, interactive=True
             )
 
         images_dir.change(
@@ -321,39 +318,39 @@ def list_images_dirs(path):
         # Caption Section
         with gr.Group(), gr.Row():
             quick_tags_text = gr.Textbox(
-                label='Quick Tags',
-                placeholder='Comma separated list of tags',
+                label="Quick Tags",
+                placeholder="Comma separated list of tags",
                 interactive=True,
             )
-            import_tags_button = gr.Button('Import', elem_id='open_folder')
+            import_tags_button = gr.Button("Import", elem_id="open_folder")
             ignore_load_tags_word_count = gr.Slider(
                 minimum=1,
                 maximum=100,
                 value=3,
                 step=1,
-                label='Ignore Imported Tags Above Word Count',
+                label="Ignore Imported Tags Above Word Count",
                 interactive=True,
             )
 
         # Next/Prev section generator
         def render_pagination():
-            gr.Button('< Prev', elem_id='open_folder').click(
+            gr.Button("< Prev", elem_id="open_folder").click(
                 paginate,
                 inputs=[page, max_page, gr.Number(-1, visible=False)],
                 outputs=[page],
             )
-            page_count = gr.Label('Page 1', label='Page')
+            page_count = gr.Label("Page 1", label="Page")
             page_goto_text = gr.Textbox(
-                label='Goto page',
-                placeholder='Page Number',
+                label="Goto page",
+                placeholder="Page Number",
                 interactive=True,
             )
-            gr.Button('Go >', elem_id='open_folder').click(
+            gr.Button("Go >", elem_id="open_folder").click(
                 paginate_go,
                 inputs=[page_goto_text, max_page],
                 outputs=[page],
             )
-            gr.Button('Next >', elem_id='open_folder').click(
+            gr.Button("Next >", elem_id="open_folder").click(
                 paginate,
                 inputs=[page, max_page, gr.Number(1, visible=False)],
                 outputs=[page],
@@ -374,19 +371,20 @@ def render_pagination():
             with gr.Row(visible=False) as row:
                 image_file = gr.Text(visible=False)
                 image_files.append(image_file)
-                image_image = gr.Image(type='filepath')
+                image_image = gr.Image(type="filepath")
                 image_images.append(image_image)
                 image_caption_text = gr.TextArea(
-                    label='Captions',
-                    placeholder='Input captions',
+                    label="Captions",
+                    placeholder="Input captions",
                     interactive=True,
                 )
                 image_caption_texts.append(image_caption_text)
-                tag_checkboxes = gr.CheckboxGroup(
-                    [], label='Tags', interactive=True
-                )
+                tag_checkboxes = gr.CheckboxGroup([], label="Tags", interactive=True)
                 save_button = gr.Button(
-                    '💾', elem_id='open_folder_small', elem_classes=['tool'], visible=False
+                    "💾",
+                    elem_id="open_folder_small",
+                    elem_classes=["tool"],
+                    visible=False,
                 )
                 save_buttons.append(save_button)
 
@@ -485,9 +483,9 @@ def render_pagination():
         )
         # Update the key on page and image dir change
         listener_kwargs = {
-            'fn': lambda p, i: f'{p}-{i}',
-            'inputs': [page, loaded_images_dir],
-            'outputs': image_update_key,
+            "fn": lambda p, i: f"{p}-{i}",
+            "inputs": [page, loaded_images_dir],
+            "outputs": image_update_key,
         }
         page.change(**listener_kwargs)
         loaded_images_dir.change(**listener_kwargs)
@@ -495,15 +493,14 @@ def render_pagination():
         # Save buttons visibility
         # (on auto-save on/off)
         auto_save.change(
-            lambda auto_save: [gr.Button(visible=not auto_save)]
-            * IMAGES_TO_SHOW,
+            lambda auto_save: [gr.Button(visible=not auto_save)] * IMAGES_TO_SHOW,
             inputs=auto_save,
             outputs=save_buttons,
         )
 
         # Page Count
         page.change(
-            lambda page, max_page: [f'Page {int(page)} / {int(max_page)}'] * 2,
+            lambda page, max_page: [f"Page {int(page)} / {int(max_page)}"] * 2,
             inputs=[page, max_page],
             outputs=[page_count1, page_count2],
             show_progress=False,
diff --git a/kohya_gui/merge_lora_gui.py b/kohya_gui/merge_lora_gui.py
index db60d92ef..662cd55f0 100644
--- a/kohya_gui/merge_lora_gui.py
+++ b/kohya_gui/merge_lora_gui.py
@@ -9,16 +9,22 @@
 from easygui import msgbox
 
 # Local module imports
-from .common_gui import get_saveasfilename_path, get_file_path, scriptdir, list_files, create_refresh_button
+from .common_gui import (
+    get_saveasfilename_path,
+    get_file_path,
+    scriptdir,
+    list_files,
+    create_refresh_button,
+)
 from .custom_logging import setup_logging
 
 # Set up logging
 log = setup_logging()
 
-folder_symbol = '\U0001f4c2'  # 📂
-refresh_symbol = '\U0001f504'  # 🔄
-save_style_symbol = '\U0001f4be'  # 💾
-document_symbol = '\U0001F4C4'  # 📄
+folder_symbol = "\U0001f4c2"  # 📂
+refresh_symbol = "\U0001f504"  # 🔄
+save_style_symbol = "\U0001f4be"  # 💾
+document_symbol = "\U0001F4C4"  # 📄
 
 PYTHON = sys.executable
 
@@ -27,7 +33,7 @@ def check_model(model):
     if not model:
         return True
     if not os.path.isfile(model):
-        msgbox(f'The provided {model} is not a file')
+        msgbox(f"The provided {model} is not a file")
         return False
     return True
 
@@ -47,14 +53,14 @@ def __init__(self, headless=False):
         self.build_tab()
 
     def save_inputs_to_json(self, file_path, inputs):
-        with open(file_path, 'w') as file:
+        with open(file_path, "w") as file:
             json.dump(inputs, file)
-        log.info(f'Saved inputs to {file_path}')
+        log.info(f"Saved inputs to {file_path}")
 
     def load_inputs_from_json(self, file_path):
-        with open(file_path, 'r') as file:
+        with open(file_path, "r") as file:
             inputs = json.load(file)
-        log.info(f'Loaded inputs from {file_path}')
+        log.info(f"Loaded inputs from {file_path}")
         return inputs
 
     def build_tab(self):
@@ -95,29 +101,34 @@ def list_save_to(path):
             current_save_dir = path
             return list(list_files(path, exts=[".ckpt", ".safetensors"], all=True))
 
-        with gr.Tab('Merge LoRA'):
+        with gr.Tab("Merge LoRA"):
             gr.Markdown(
-                'This utility can merge up to 4 LoRA together or alternatively merge up to 4 LoRA into a SD checkpoint.'
+                "This utility can merge up to 4 LoRA together or alternatively merge up to 4 LoRA into a SD checkpoint."
             )
 
-            lora_ext = gr.Textbox(value='*.safetensors *.pt', visible=False)
-            lora_ext_name = gr.Textbox(value='LoRA model types', visible=False)
-            ckpt_ext = gr.Textbox(value='*.safetensors *.ckpt', visible=False)
-            ckpt_ext_name = gr.Textbox(value='SD model types', visible=False)
+            lora_ext = gr.Textbox(value="*.safetensors *.pt", visible=False)
+            lora_ext_name = gr.Textbox(value="LoRA model types", visible=False)
+            ckpt_ext = gr.Textbox(value="*.safetensors *.ckpt", visible=False)
+            ckpt_ext_name = gr.Textbox(value="SD model types", visible=False)
 
             with gr.Group(), gr.Row():
                 sd_model = gr.Dropdown(
-                    label='SD Model (Optional. Stable Diffusion model path, if you want to merge it with LoRA files)',
+                    label="SD Model (Optional. Stable Diffusion model path, if you want to merge it with LoRA files)",
                     interactive=True,
                     choices=[""] + list_sd_models(current_sd_model_dir),
                     value="",
                    allow_custom_value=True,
                 )
-                create_refresh_button(sd_model, lambda: None, lambda: {"choices": list_sd_models(current_sd_model_dir)}, "open_folder_small")
+                create_refresh_button(
+                    sd_model,
+                    lambda: None,
+                    lambda: {"choices": list_sd_models(current_sd_model_dir)},
+                    "open_folder_small",
+                )
                 sd_model_file = gr.Button(
                     folder_symbol,
-                    elem_id='open_folder_small',
-                    elem_classes=['tool'],
+                    elem_id="open_folder_small",
+                    elem_classes=["tool"],
                     visible=(not self.headless),
                 )
                 sd_model_file.click(
@@ -126,7 +137,7 @@ def list_save_to(path):
                    outputs=sd_model,
                     show_progress=False,
                 )
-                sdxl_model = gr.Checkbox(label='SDXL model', value=False)
+                sdxl_model = gr.Checkbox(label="SDXL model", value=False)
 
                 sd_model.change(
                     fn=lambda path: gr.Dropdown(choices=[""] + list_sd_models(path)),
@@ -143,11 +154,16 @@ def list_save_to(path):
                     value="",
                     allow_custom_value=True,
                 )
-                create_refresh_button(lora_a_model, lambda: None, lambda: {"choices": list_a_models(current_a_model_dir)}, "open_folder_small")
+                create_refresh_button(
+                    lora_a_model,
+                    lambda: None,
+                    lambda: {"choices": list_a_models(current_a_model_dir)},
+                    "open_folder_small",
+                )
                 button_lora_a_model_file = gr.Button(
                     folder_symbol,
-                    elem_id='open_folder_small',
-                    elem_classes=['tool'],
+                    elem_id="open_folder_small",
+                    elem_classes=["tool"],
                     visible=(not self.headless),
                 )
                 button_lora_a_model_file.click(
@@ -164,11 +180,16 @@ def list_save_to(path):
                     value="",
                     allow_custom_value=True,
                 )
-                create_refresh_button(lora_b_model, lambda: None, lambda: {"choices": list_b_models(current_b_model_dir)}, "open_folder_small")
+                create_refresh_button(
+                    lora_b_model,
+                    lambda: None,
+                    lambda: {"choices": list_b_models(current_b_model_dir)},
+                    "open_folder_small",
+                )
                 button_lora_b_model_file = gr.Button(
                     folder_symbol,
-                    elem_id='open_folder_small',
-                    elem_classes=['tool'],
+                    elem_id="open_folder_small",
+                    elem_classes=["tool"],
                     visible=(not self.headless),
                 )
                 button_lora_b_model_file.click(
@@ -193,7 +214,7 @@ def list_save_to(path):
 
             with gr.Row():
                 ratio_a = gr.Slider(
-                    label='Model A merge ratio (eg: 0.5 mean 50%)',
+                    label="Model A merge ratio (eg: 0.5 mean 50%)",
                     minimum=0,
                     maximum=1,
                     step=0.01,
@@ -202,7 +223,7 @@ def list_save_to(path):
                 )
 
                 ratio_b = gr.Slider(
-                    label='Model B merge ratio (eg: 0.5 mean 50%)',
+                    label="Model B merge ratio (eg: 0.5 mean 50%)",
                     minimum=0,
                     maximum=1,
                     step=0.01,
@@ -218,11 +239,16 @@ def list_save_to(path):
                     value="",
                     allow_custom_value=True,
                 )
-                create_refresh_button(lora_c_model, lambda: None, lambda: {"choices": list_c_models(current_c_model_dir)}, "open_folder_small")
+                create_refresh_button(
+                    lora_c_model,
+                    lambda: None,
+                    lambda: {"choices": list_c_models(current_c_model_dir)},
+                    "open_folder_small",
+                )
                 button_lora_c_model_file = gr.Button(
                     folder_symbol,
-                    elem_id='open_folder_small',
-                    elem_classes=['tool'],
+                    elem_id="open_folder_small",
+                    elem_classes=["tool"],
                     visible=(not self.headless),
                 )
                 button_lora_c_model_file.click(
@@ -239,11 +265,16 @@ def list_save_to(path):
                     value="",
                     allow_custom_value=True,
                 )
-                create_refresh_button(lora_d_model, lambda: None, lambda: {"choices": list_d_models(current_d_model_dir)}, "open_folder_small")
+                create_refresh_button(
+                    lora_d_model,
+                    lambda: None,
+                    lambda: {"choices": list_d_models(current_d_model_dir)},
+                    "open_folder_small",
+                )
                 button_lora_d_model_file = gr.Button(
                     folder_symbol,
-                    elem_id='open_folder_small',
-                    elem_classes=['tool'],
+                    elem_id="open_folder_small",
+                    elem_classes=["tool"],
                     visible=(not self.headless),
                 )
                 button_lora_d_model_file.click(
@@ -267,7 +298,7 @@ def list_save_to(path):
 
             with gr.Row():
                 ratio_c = gr.Slider(
-                    label='Model C merge ratio (eg: 0.5 mean 50%)',
+                    label="Model C merge ratio (eg: 0.5 mean 50%)",
                     minimum=0,
                     maximum=1,
                     step=0.01,
@@ -276,7 +307,7 @@ def list_save_to(path):
                 )
 
                 ratio_d = gr.Slider(
-                    label='Model D merge ratio (eg: 0.5 mean 50%)',
+                    label="Model D merge ratio (eg: 0.5 mean 50%)",
                     minimum=0,
                     maximum=1,
                     step=0.01,
@@ -286,17 +317,22 @@ def list_save_to(path):
 
             with gr.Group(), gr.Row():
                 save_to = gr.Dropdown(
-                    label='Save to (path for the file to save...)',
+                    label="Save to (path for the file to save...)",
                     interactive=True,
                     choices=[""] + list_save_to(current_d_model_dir),
                     value="",
                     allow_custom_value=True,
                 )
-                create_refresh_button(save_to, lambda: None, lambda: {"choices": list_save_to(current_save_dir)}, "open_folder_small")
+                create_refresh_button(
+                    save_to,
+                    lambda: None,
+                    lambda: {"choices": list_save_to(current_save_dir)},
+                    "open_folder_small",
+                )
                 button_save_to = gr.Button(
                     folder_symbol,
-                    elem_id='open_folder_small',
-                    elem_classes=['tool'],
+                    elem_id="open_folder_small",
+                    elem_classes=["tool"],
                     visible=(not self.headless),
                 )
                 button_save_to.click(
@@ -306,15 +342,15 @@ def list_save_to(path):
                     show_progress=False,
                 )
                 precision = gr.Radio(
-                    label='Merge precision',
-                    choices=['fp16', 'bf16', 'float'],
-                    value='float',
+                    label="Merge precision",
+                    choices=["fp16", "bf16", "float"],
+                    value="float",
                     interactive=True,
                 )
                 save_precision = gr.Radio(
-                    label='Save precision',
-                    choices=['fp16', 'bf16', 'float'],
-                    value='fp16',
+                    label="Save precision",
+                    choices=["fp16", "bf16", "float"],
+                    value="fp16",
                     interactive=True,
                 )
 
@@ -325,7 +361,7 @@ def list_save_to(path):
                 show_progress=False,
             )
 
-            merge_button = gr.Button('Merge model')
+            merge_button = gr.Button("Merge model")
 
             merge_button.click(
                 self.merge_lora,
@@ -364,7 +400,7 @@ def merge_lora(
         save_precision,
     ):
 
-        log.info('Merge model...')
+        log.info("Merge model...")
         models = [
             sd_model,
             lora_a_model,
@@ -377,7 +413,7 @@ def merge_lora(
 
         if not verify_conditions(sd_model, lora_models):
             log.info(
-                'Warning: Either provide at least one LoRa model along with the sd_model or at least two LoRa models if no sd_model is provided.'
+                "Warning: Either provide at least one LoRa model along with the sd_model or at least two LoRa models if no sd_model is provided."
            )
            return
@@ -386,36 +422,36 @@ def merge_lora(
            return
        if not sdxl_model:
-            run_cmd = fr'"{PYTHON}" "{scriptdir}/sd-scripts/networks/merge_lora.py"'
+            run_cmd = rf'"{PYTHON}" "{scriptdir}/sd-scripts/networks/merge_lora.py"'
        else:
            run_cmd = (
-                fr'"{PYTHON}" "{scriptdir}/sd-scripts/networks/sdxl_merge_lora.py"'
+                rf'"{PYTHON}" "{scriptdir}/sd-scripts/networks/sdxl_merge_lora.py"'
            )
        if sd_model:
-            run_cmd += fr' --sd_model "{sd_model}"'
-        run_cmd += f' --save_precision {save_precision}'
-        run_cmd += f' --precision {precision}'
-        run_cmd += fr' --save_to "{save_to}"'
+            run_cmd += rf' --sd_model "{sd_model}"'
+        run_cmd += f" --save_precision {save_precision}"
+        run_cmd += f" --precision {precision}"
+        run_cmd += rf' --save_to "{save_to}"'
        # Create a space-separated string of non-empty models (from the second element onwards), enclosed in double quotes
-        models_cmd = ' '.join([fr'"{model}"' for model in lora_models if model])
+        models_cmd = " ".join([rf'"{model}"' for model in lora_models if model])
        # Create a space-separated string of non-zero ratios corresponding to non-empty LoRa models
-        valid_ratios = [
-            ratios[i] for i, model in enumerate(lora_models) if model
-        ]
-        ratios_cmd = ' '.join([str(ratio) for ratio in valid_ratios])
+        valid_ratios = [ratios[i] for i, model in enumerate(lora_models) if model]
+        ratios_cmd = " ".join([str(ratio) for ratio in valid_ratios])
        if models_cmd:
-            run_cmd += f' --models {models_cmd}'
-            run_cmd += f' --ratios {ratios_cmd}'
+            run_cmd += f" --models {models_cmd}"
+            run_cmd += f" --ratios {ratios_cmd}"
        log.info(run_cmd)
        env = os.environ.copy()
-        env['PYTHONPATH'] = fr"{scriptdir}{os.pathsep}{scriptdir}/sd-scripts{os.pathsep}{env.get('PYTHONPATH', '')}"
+        env["PYTHONPATH"] = (
+            rf"{scriptdir}{os.pathsep}{scriptdir}/sd-scripts{os.pathsep}{env.get('PYTHONPATH', '')}"
+        )
        # Run the command
        subprocess.run(run_cmd, shell=True, env=env)
-        log.info('Done merging...')
+        log.info("Done merging...")
diff --git a/kohya_gui/merge_lycoris_gui.py b/kohya_gui/merge_lycoris_gui.py
index 733d1135a..bb6ae9ee6 100644
--- a/kohya_gui/merge_lycoris_gui.py
+++ b/kohya_gui/merge_lycoris_gui.py
@@ -16,10 +16,10 @@
 # Set up logging
 log = setup_logging()
-folder_symbol = '\U0001f4c2'  # 📂
-refresh_symbol = '\U0001f504'  # 🔄
-save_style_symbol = '\U0001f4be'  # 💾
-document_symbol = '\U0001F4C4'  # 📄
+folder_symbol = "\U0001f4c2"  # 📂
+refresh_symbol = "\U0001f504"  # 🔄
+save_style_symbol = "\U0001f4be"  # 💾
+document_symbol = "\U0001F4C4"  # 📄
 PYTHON = sys.executable
@@ -34,29 +34,31 @@ def merge_lycoris(
    is_sdxl,
    is_v2,
):
-    log.info('Merge model...')
-
-    run_cmd = fr'"{PYTHON}" "{scriptdir}/tools/merge_lycoris.py"'
-    run_cmd += fr' "{base_model}"'
-    run_cmd += fr' "{lycoris_model}"'
-    run_cmd += fr' "{output_name}"'
-    run_cmd += f' --weight {weight}'
-    run_cmd += f' --device {device}'
-    run_cmd += f' --dtype {dtype}'
+    log.info("Merge model...")
+
+    run_cmd = rf'"{PYTHON}" "{scriptdir}/tools/merge_lycoris.py"'
+    run_cmd += rf' "{base_model}"'
+    run_cmd += rf' "{lycoris_model}"'
+    run_cmd += rf' "{output_name}"'
+    run_cmd += f" --weight {weight}"
+    run_cmd += f" --device {device}"
+    run_cmd += f" --dtype {dtype}"
    if is_sdxl:
-        run_cmd += f' --is_sdxl'
+        run_cmd += f" --is_sdxl"
    if is_v2:
-        run_cmd += f' --is_v2'
+        run_cmd += f" --is_v2"
    log.info(run_cmd)
    env = os.environ.copy()
-    env['PYTHONPATH'] = fr"{scriptdir}{os.pathsep}{scriptdir}/sd-scripts{os.pathsep}{env.get('PYTHONPATH', '')}"
+    env["PYTHONPATH"] = (
+        rf"{scriptdir}{os.pathsep}{scriptdir}/sd-scripts{os.pathsep}{env.get('PYTHONPATH', '')}"
+    )
    # Run the command
    subprocess.run(run_cmd, shell=True, env=env)
-    log.info('Done merging...')
+    log.info("Done merging...")
###
@@ -84,30 +86,33 @@ def list_save_to(path):
        current_save_dir = path
        return list(list_files(path, exts=[".ckpt", ".safetensors"], all=True))
-    with gr.Tab('Merge LyCORIS'):
-        gr.Markdown(
-            'This utility can merge a LyCORIS model into a SD checkpoint.'
-        )
+    with gr.Tab("Merge LyCORIS"):
+        gr.Markdown("This utility can merge a LyCORIS model into a SD checkpoint.")
-        lora_ext = gr.Textbox(value='*.safetensors *.pt', visible=False)
-        lora_ext_name = gr.Textbox(value='LoRA model types', visible=False)
-        ckpt_ext = gr.Textbox(value='*.safetensors *.ckpt', visible=False)
-        ckpt_ext_name = gr.Textbox(value='SD model types', visible=False)
+        lora_ext = gr.Textbox(value="*.safetensors *.pt", visible=False)
+        lora_ext_name = gr.Textbox(value="LoRA model types", visible=False)
+        ckpt_ext = gr.Textbox(value="*.safetensors *.ckpt", visible=False)
+        ckpt_ext_name = gr.Textbox(value="SD model types", visible=False)
        with gr.Group(), gr.Row():
            base_model = gr.Dropdown(
-                label='SD Model (Optional Stable Diffusion base model)',
+                label="SD Model (Optional Stable Diffusion base model)",
                interactive=True,
-                info='Provide a SD file path that you want to merge with the LyCORIS file',
+                info="Provide a SD file path that you want to merge with the LyCORIS file",
                choices=[""] + list_models(current_save_dir),
                value="",
                allow_custom_value=True,
            )
-            create_refresh_button(base_model, lambda: None, lambda: {"choices": list_models(current_model_dir)}, "open_folder_small")
+            create_refresh_button(
+                base_model,
+                lambda: None,
+                lambda: {"choices": list_models(current_model_dir)},
+                "open_folder_small",
+            )
            base_model_file = gr.Button(
                folder_symbol,
-                elem_id='open_folder_small',
-                elem_classes=['tool'],
+                elem_id="open_folder_small",
+                elem_classes=["tool"],
                visible=(not headless),
            )
            base_model_file.click(
@@ -118,7 +123,7 @@ def list_save_to(path):
            )
            lycoris_model = gr.Dropdown(
-                label='LyCORIS model (path to the LyCORIS model)',
+                label="LyCORIS model (path to the LyCORIS model)",
                interactive=True,
                choices=[""] + list_lycoris_model(current_save_dir),
                value="",
@@ -126,8 +131,8 @@ def list_save_to(path):
            )
            button_lycoris_model_file = gr.Button(
                folder_symbol,
-                elem_id='open_folder_small',
-                elem_classes=['tool'],
+                elem_id="open_folder_small",
+                elem_classes=["tool"],
                visible=(not headless),
            )
            button_lycoris_model_file.click(
@@ -152,7 +157,7 @@ def list_save_to(path):
        with gr.Row():
            weight = gr.Slider(
-                label='Model A merge ratio (eg: 0.5 mean 50%)',
+                label="Model A merge ratio (eg: 0.5 mean 50%)",
                minimum=0,
                maximum=1,
                step=0.01,
@@ -162,17 +167,22 @@ def list_save_to(path):
        with gr.Group(), gr.Row():
            output_name = gr.Dropdown(
-                label='Save to (path for the checkpoint file to save...)',
+                label="Save to (path for the checkpoint file to save...)",
                interactive=True,
                choices=[""] + list_save_to(current_save_dir),
                value="",
                allow_custom_value=True,
            )
-            create_refresh_button(output_name, lambda: None, lambda: {"choices": list_save_to(current_save_dir)}, "open_folder_small")
+            create_refresh_button(
+                output_name,
+                lambda: None,
+                lambda: {"choices": list_save_to(current_save_dir)},
+                "open_folder_small",
+            )
            button_output_name = gr.Button(
                folder_symbol,
-                elem_id='open_folder_small',
-                elem_classes=['tool'],
+                elem_id="open_folder_small",
+                elem_classes=["tool"],
                visible=(not headless),
            )
            button_output_name.click(
@@ -182,26 +192,26 @@ def list_save_to(path):
                show_progress=False,
            )
            dtype = gr.Radio(
-                label='Save dtype',
+                label="Save dtype",
                choices=[
-                    'float',
-                    'float16',
-                    'float32',
-                    'float64',
-                    'bfloat',
-                    'bfloat16',
+                    "float",
+                    "float16",
+                    "float32",
+                    "float64",
+                    "bfloat",
+                    "bfloat16",
                ],
-                value='float16',
+                value="float16",
                interactive=True,
            )
            device = gr.Radio(
-                label='Device',
+                label="Device",
                choices=[
-                    'cpu',
-                    'cuda',
+                    "cpu",
+                    "cuda",
                ],
-                value='cpu',
+                value="cpu",
                interactive=True,
            )
@@ -213,10 +223,10 @@ def list_save_to(path):
            )
        with gr.Row():
-            is_sdxl = gr.Checkbox(label='is SDXL', value=False, interactive=True)
-            is_v2 = gr.Checkbox(label='is v2', value=False, interactive=True)
+            is_sdxl = gr.Checkbox(label="is SDXL", value=False, interactive=True)
+            is_v2 = gr.Checkbox(label="is v2", value=False, interactive=True)
-        merge_button = gr.Button('Merge model')
+        merge_button = gr.Button("Merge model")
        merge_button.click(
            merge_lycoris,
diff --git a/kohya_gui/resize_lora_gui.py b/kohya_gui/resize_lora_gui.py
index be71bd469..21b3ea533 100644
--- a/kohya_gui/resize_lora_gui.py
+++ b/kohya_gui/resize_lora_gui.py
@@ -3,17 +3,23 @@
 import subprocess
 import os
 import sys
-from .common_gui import get_saveasfilename_path, get_file_path, scriptdir, list_files, create_refresh_button
+from .common_gui import (
+    get_saveasfilename_path,
+    get_file_path,
+    scriptdir,
+    list_files,
+    create_refresh_button,
+)
 from .custom_logging import setup_logging
 # Set up logging
 log = setup_logging()
-folder_symbol = '\U0001f4c2'  # 📂
-refresh_symbol = '\U0001f504'  # 🔄
-save_style_symbol = '\U0001f4be'  # 💾
-document_symbol = '\U0001F4C4'  # 📄
+folder_symbol = "\U0001f4c2"  # 📂
+refresh_symbol = "\U0001f504"  # 🔄
+save_style_symbol = "\U0001f4be"  # 💾
+document_symbol = "\U0001F4C4"  # 📄
 PYTHON = sys.executable
@@ -29,57 +35,57 @@ def resize_lora(
    verbose,
):
    # Check for caption_text_input
-    if model == '':
-        msgbox('Invalid model file')
+    if model == "":
+        msgbox("Invalid model file")
        return
    # Check if source model exist
    if not os.path.isfile(model):
-        msgbox('The provided model is not a file')
+        msgbox("The provided model is not a file")
        return
-    if dynamic_method == 'sv_ratio':
+    if dynamic_method == "sv_ratio":
        if float(dynamic_param) < 2:
-            msgbox(
-                f'Dynamic parameter for {dynamic_method} need to be 2 or greater...'
-            )
+            msgbox(f"Dynamic parameter for {dynamic_method} need to be 2 or greater...")
            return
-    if dynamic_method == 'sv_fro' or dynamic_method == 'sv_cumulative':
+    if dynamic_method == "sv_fro" or dynamic_method == "sv_cumulative":
        if float(dynamic_param) < 0 or float(dynamic_param) > 1:
            msgbox(
-                f'Dynamic parameter for {dynamic_method} need to be between 0 and 1...'
+                f"Dynamic parameter for {dynamic_method} need to be between 0 and 1..."
            )
            return
    # Check if save_to end with one of the defines extension. If not add .safetensors.
-    if not save_to.endswith(('.pt', '.safetensors')):
-        save_to += '.safetensors'
-
-    if device == '':
-        device = 'cuda'
-
-    run_cmd = fr'"{PYTHON}" "{scriptdir}/sd-scripts/networks/resize_lora.py"'
-    run_cmd += f' --save_precision {save_precision}'
-    run_cmd += fr' --save_to "{save_to}"'
-    run_cmd += fr' --model "{model}"'
-    run_cmd += f' --new_rank {new_rank}'
-    run_cmd += f' --device {device}'
-    if not dynamic_method == 'None':
-        run_cmd += f' --dynamic_method {dynamic_method}'
-        run_cmd += f' --dynamic_param {dynamic_param}'
+    if not save_to.endswith((".pt", ".safetensors")):
+        save_to += ".safetensors"
+
+    if device == "":
+        device = "cuda"
+
+    run_cmd = rf'"{PYTHON}" "{scriptdir}/sd-scripts/networks/resize_lora.py"'
+    run_cmd += f" --save_precision {save_precision}"
+    run_cmd += rf' --save_to "{save_to}"'
+    run_cmd += rf' --model "{model}"'
+    run_cmd += f" --new_rank {new_rank}"
+    run_cmd += f" --device {device}"
+    if not dynamic_method == "None":
+        run_cmd += f" --dynamic_method {dynamic_method}"
+        run_cmd += f" --dynamic_param {dynamic_param}"
    if verbose:
-        run_cmd += f' --verbose'
+        run_cmd += f" --verbose"
    log.info(run_cmd)
    env = os.environ.copy()
-    env['PYTHONPATH'] = fr"{scriptdir}{os.pathsep}{scriptdir}/sd-scripts{os.pathsep}{env.get('PYTHONPATH', '')}"
+    env["PYTHONPATH"] = (
+        rf"{scriptdir}{os.pathsep}{scriptdir}/sd-scripts{os.pathsep}{env.get('PYTHONPATH', '')}"
+    )
    # Run the command
    subprocess.run(run_cmd, shell=True, env=env)
-    log.info('Done resizing...')
+    log.info("Done resizing...")
###
@@ -101,25 +107,30 @@ def list_save_to(path):
        current_save_dir = path
        return list(list_files(path, exts=[".pt", ".safetensors"], all=True))
-    with gr.Tab('Resize LoRA'):
-        gr.Markdown('This utility can resize a LoRA.')
+    with gr.Tab("Resize LoRA"):
+        gr.Markdown("This utility can resize a LoRA.")
-        lora_ext = gr.Textbox(value='*.safetensors *.pt', visible=False)
-        lora_ext_name = gr.Textbox(value='LoRA model types', visible=False)
+        lora_ext = gr.Textbox(value="*.safetensors *.pt", visible=False)
+        lora_ext_name = gr.Textbox(value="LoRA model types", visible=False)
        with gr.Group(), gr.Row():
            model = gr.Dropdown(
-                label='Source LoRA (path to the LoRA to resize)',
+                label="Source LoRA (path to the LoRA to resize)",
                interactive=True,
                choices=[""] + list_models(current_model_dir),
                value="",
                allow_custom_value=True,
            )
-            create_refresh_button(model, lambda: None, lambda: {"choices": list_models(current_model_dir)}, "open_folder_small")
+            create_refresh_button(
+                model,
+                lambda: None,
+                lambda: {"choices": list_models(current_model_dir)},
+                "open_folder_small",
+            )
            button_lora_a_model_file = gr.Button(
                folder_symbol,
-                elem_id='open_folder_small',
-                elem_classes=['tool'],
+                elem_id="open_folder_small",
+                elem_classes=["tool"],
                visible=(not headless),
            )
            button_lora_a_model_file.click(
@@ -129,17 +140,22 @@ def list_save_to(path):
                show_progress=False,
            )
            save_to = gr.Dropdown(
-                label='Save to (path for the LoRA file to save...)',
+                label="Save to (path for the LoRA file to save...)",
                interactive=True,
                choices=[""] + list_save_to(current_save_dir),
                value="",
                allow_custom_value=True,
            )
-            create_refresh_button(save_to, lambda: None, lambda: {"choices": list_save_to(current_save_dir)}, "open_folder_small")
+            create_refresh_button(
+                save_to,
+                lambda: None,
+                lambda: {"choices": list_save_to(current_save_dir)},
+                "open_folder_small",
+            )
            button_save_to = gr.Button(
                folder_symbol,
-                elem_id='open_folder_small',
-                elem_classes=['tool'],
+                elem_id="open_folder_small",
+                elem_classes=["tool"],
                visible=(not headless),
            )
            button_save_to.click(
@@ -162,7 +178,7 @@ def list_save_to(path):
            )
        with gr.Row():
            new_rank = gr.Slider(
-                label='Desired LoRA rank',
+                label="Desired LoRA rank",
                minimum=1,
                maximum=1024,
                step=1,
@@ -170,37 +186,37 @@ def list_save_to(path):
                interactive=True,
            )
            dynamic_method = gr.Radio(
-                choices=['None', 'sv_ratio', 'sv_fro', 'sv_cumulative'],
-                value='sv_fro',
-                label='Dynamic method',
+                choices=["None", "sv_ratio", "sv_fro", "sv_cumulative"],
+                value="sv_fro",
+                label="Dynamic method",
                interactive=True,
            )
            dynamic_param = gr.Textbox(
-                label='Dynamic parameter',
-                value='0.9',
+                label="Dynamic parameter",
+                value="0.9",
                interactive=True,
-                placeholder='Value for the dynamic method selected.',
+                placeholder="Value for the dynamic method selected.",
            )
        with gr.Row():
-            verbose = gr.Checkbox(label='Verbose logging', value=True)
+            verbose = gr.Checkbox(label="Verbose logging", value=True)
            save_precision = gr.Radio(
-                label='Save precision',
-                choices=['fp16', 'bf16', 'float'],
-                value='fp16',
+                label="Save precision",
+                choices=["fp16", "bf16", "float"],
+                value="fp16",
                interactive=True,
            )
            device = gr.Radio(
-                label='Device',
+                label="Device",
                choices=[
-                    'cpu',
-                    'cuda',
+                    "cpu",
+                    "cuda",
                ],
-                value='cuda',
+                value="cuda",
                interactive=True,
            )
-        convert_button = gr.Button('Resize model')
+        convert_button = gr.Button("Resize model")
        convert_button.click(
            resize_lora,
diff --git a/kohya_gui/svd_merge_lora_gui.py b/kohya_gui/svd_merge_lora_gui.py
index 196089519..c14c9c6ad 100644
--- a/kohya_gui/svd_merge_lora_gui.py
+++ b/kohya_gui/svd_merge_lora_gui.py
@@ -5,7 +5,6 @@
 import sys
 from .common_gui import (
     get_saveasfilename_path,
-    get_any_file_path,
     get_file_path,
     scriptdir,
     list_files,
@@ -17,10 +16,10 @@
 # Set up logging
 log = setup_logging()
-folder_symbol = '\U0001f4c2'  # 📂
-refresh_symbol = '\U0001f504'  # 🔄
-save_style_symbol = '\U0001f4be'  # 💾
-document_symbol = '\U0001F4C4'  # 📄
+folder_symbol = "\U0001f4c2"  # 📂
+refresh_symbol = "\U0001f504"  # 🔄
+save_style_symbol = "\U0001f4be"  # 💾
+document_symbol = "\U0001F4C4"  # 📄
 PYTHON = sys.executable
@@ -53,49 +52,51 @@ def svd_merge_lora(
        ratio_c /= total_ratio
        ratio_d /= total_ratio
-    run_cmd = fr'"{PYTHON}" "{scriptdir}/sd-scripts/networks/svd_merge_lora.py"'
-    run_cmd += f' --save_precision {save_precision}'
-    run_cmd += f' --precision {precision}'
-    run_cmd += fr' --save_to "{save_to}"'
+    run_cmd = rf'"{PYTHON}" "{scriptdir}/sd-scripts/networks/svd_merge_lora.py"'
+    run_cmd += f" --save_precision {save_precision}"
+    run_cmd += f" --precision {precision}"
+    run_cmd += rf' --save_to "{save_to}"'
-    run_cmd_models = ' --models'
-    run_cmd_ratios = ' --ratios'
+    run_cmd_models = " --models"
+    run_cmd_ratios = " --ratios"
    # Add non-empty models and their ratios to the command
    if lora_a_model:
        if not os.path.isfile(lora_a_model):
-            msgbox('The provided model A is not a file')
+            msgbox("The provided model A is not a file")
            return
-        run_cmd_models += fr' "{lora_a_model}"'
-        run_cmd_ratios += f' {ratio_a}'
+        run_cmd_models += rf' "{lora_a_model}"'
+        run_cmd_ratios += f" {ratio_a}"
    if lora_b_model:
        if not os.path.isfile(lora_b_model):
-            msgbox('The provided model B is not a file')
+            msgbox("The provided model B is not a file")
            return
-        run_cmd_models += fr' "{lora_b_model}"'
-        run_cmd_ratios += f' {ratio_b}'
+        run_cmd_models += rf' "{lora_b_model}"'
+        run_cmd_ratios += f" {ratio_b}"
    if lora_c_model:
        if not os.path.isfile(lora_c_model):
-            msgbox('The provided model C is not a file')
+            msgbox("The provided model C is not a file")
            return
-        run_cmd_models += fr' "{lora_c_model}"'
-        run_cmd_ratios += f' {ratio_c}'
+        run_cmd_models += rf' "{lora_c_model}"'
+        run_cmd_ratios += f" {ratio_c}"
    if lora_d_model:
        if not os.path.isfile(lora_d_model):
-            msgbox('The provided model D is not a file')
+            msgbox("The provided model D is not a file")
            return
-        run_cmd_models += fr' "{lora_d_model}"'
-        run_cmd_ratios += f' {ratio_d}'
+        run_cmd_models += rf' "{lora_d_model}"'
+        run_cmd_ratios += f" {ratio_d}"
    run_cmd += run_cmd_models
    run_cmd += run_cmd_ratios
-    run_cmd += f' --device {device}'
+    run_cmd += f" --device {device}"
    run_cmd += f' --new_rank "{new_rank}"'
    run_cmd += f' --new_conv_rank "{new_conv_rank}"'
    log.info(run_cmd)
    env = os.environ.copy()
-    env['PYTHONPATH'] = fr"{scriptdir}{os.pathsep}{scriptdir}/sd-scripts{os.pathsep}{env.get('PYTHONPATH', '')}"
+    env["PYTHONPATH"] = (
+        rf"{scriptdir}{os.pathsep}{scriptdir}/sd-scripts{os.pathsep}{env.get('PYTHONPATH', '')}"
+    )
    # Run the command
    subprocess.run(run_cmd, shell=True, env=env)
@@ -138,13 +139,13 @@ def list_save_to(path):
        current_save_dir = path
        return list(list_files(path, exts=[".pt", ".safetensors"], all=True))
-    with gr.Tab('Merge LoRA (SVD)'):
+    with gr.Tab("Merge LoRA (SVD)"):
        gr.Markdown(
-            'This utility can merge two LoRA networks together into a new LoRA.'
+            "This utility can merge two LoRA networks together into a new LoRA."
        )
-        lora_ext = gr.Textbox(value='*.safetensors *.pt', visible=False)
-        lora_ext_name = gr.Textbox(value='LoRA model types', visible=False)
+        lora_ext = gr.Textbox(value="*.safetensors *.pt", visible=False)
+        lora_ext_name = gr.Textbox(value="LoRA model types", visible=False)
        with gr.Group(), gr.Row():
            lora_a_model = gr.Dropdown(
@@ -154,11 +155,16 @@ def list_save_to(path):
                value="",
                allow_custom_value=True,
            )
-            create_refresh_button(lora_a_model, lambda: None, lambda: {"choices": list_a_models(current_a_model_dir)}, "open_folder_small")
+            create_refresh_button(
+                lora_a_model,
+                lambda: None,
+                lambda: {"choices": list_a_models(current_a_model_dir)},
+                "open_folder_small",
+            )
            button_lora_a_model_file = gr.Button(
                folder_symbol,
-                elem_id='open_folder_small',
-                elem_classes=['tool'],
+                elem_id="open_folder_small",
+                elem_classes=["tool"],
                visible=(not headless),
            )
            button_lora_a_model_file.click(
@@ -175,11 +181,16 @@ def list_save_to(path):
                value="",
                allow_custom_value=True,
            )
-            create_refresh_button(lora_b_model, lambda: None, lambda: {"choices": list_b_models(current_b_model_dir)}, "open_folder_small")
+            create_refresh_button(
+                lora_b_model,
+                lambda: None,
+                lambda: {"choices": list_b_models(current_b_model_dir)},
+                "open_folder_small",
+            )
            button_lora_b_model_file = gr.Button(
                folder_symbol,
-                elem_id='open_folder_small',
-                elem_classes=['tool'],
+                elem_id="open_folder_small",
+                elem_classes=["tool"],
                visible=(not headless),
            )
            button_lora_b_model_file.click(
@@ -202,7 +213,7 @@ def list_save_to(path):
            )
        with gr.Row():
            ratio_a = gr.Slider(
-                label='Merge ratio model A',
+                label="Merge ratio model A",
                minimum=0,
                maximum=1,
                step=0.01,
@@ -210,7 +221,7 @@ def list_save_to(path):
                interactive=True,
            )
            ratio_b = gr.Slider(
-                label='Merge ratio model B',
+                label="Merge ratio model B",
                minimum=0,
                maximum=1,
                step=0.01,
@@ -225,11 +236,16 @@ def list_save_to(path):
                value="",
                allow_custom_value=True,
            )
-            create_refresh_button(lora_c_model, lambda: None, lambda: {"choices": list_c_models(current_c_model_dir)}, "open_folder_small")
+            create_refresh_button(
+                lora_c_model,
+                lambda: None,
+                lambda: {"choices": list_c_models(current_c_model_dir)},
+                "open_folder_small",
+            )
            button_lora_c_model_file = gr.Button(
                folder_symbol,
-                elem_id='open_folder_small',
-                elem_classes=['tool'],
+                elem_id="open_folder_small",
+                elem_classes=["tool"],
                visible=(not headless),
            )
            button_lora_c_model_file.click(
@@ -246,11 +262,16 @@ def list_save_to(path):
                value="",
                allow_custom_value=True,
            )
-            create_refresh_button(lora_d_model, lambda: None, lambda: {"choices": list_d_models(current_d_model_dir)}, "open_folder_small")
+            create_refresh_button(
+                lora_d_model,
+                lambda: None,
+                lambda: {"choices": list_d_models(current_d_model_dir)},
+                "open_folder_small",
+            )
            button_lora_d_model_file = gr.Button(
                folder_symbol,
-                elem_id='open_folder_small',
-                elem_classes=['tool'],
+                elem_id="open_folder_small",
+                elem_classes=["tool"],
                visible=(not headless),
            )
            button_lora_d_model_file.click(
@@ -274,7 +295,7 @@ def list_save_to(path):
            )
        with gr.Row():
            ratio_c = gr.Slider(
-                label='Merge ratio model C',
+                label="Merge ratio model C",
                minimum=0,
                maximum=1,
                step=0.01,
@@ -282,7 +303,7 @@ def list_save_to(path):
                interactive=True,
            )
            ratio_d = gr.Slider(
-                label='Merge ratio model D',
+                label="Merge ratio model D",
                minimum=0,
                maximum=1,
                step=0.01,
@@ -291,7 +312,7 @@ def list_save_to(path):
            )
        with gr.Row():
            new_rank = gr.Slider(
-                label='New Rank',
+                label="New Rank",
                minimum=1,
                maximum=1024,
                step=1,
@@ -299,7 +320,7 @@ def list_save_to(path):
                interactive=True,
            )
            new_conv_rank = gr.Slider(
-                label='New Conv Rank',
+                label="New Conv Rank",
                minimum=1,
                maximum=1024,
                step=1,
@@ -309,17 +330,22 @@ def list_save_to(path):
        with gr.Group(), gr.Row():
            save_to = gr.Dropdown(
-                label='Save to (path for the new LoRA file to save...)',
+                label="Save to (path for the new LoRA file to save...)",
                interactive=True,
                choices=[""] + list_save_to(current_d_model_dir),
                value="",
                allow_custom_value=True,
            )
-            create_refresh_button(save_to, lambda: None, lambda: {"choices": list_save_to(current_save_dir)}, "open_folder_small")
+            create_refresh_button(
+                save_to,
+                lambda: None,
+                lambda: {"choices": list_save_to(current_save_dir)},
+                "open_folder_small",
+            )
            button_save_to = gr.Button(
                folder_symbol,
-                elem_id='open_folder_small',
-                elem_classes=['tool'],
+                elem_id="open_folder_small",
+                elem_classes=["tool"],
                visible=(not headless),
            )
            button_save_to.click(
@@ -336,28 +362,28 @@ def list_save_to(path):
            )
        with gr.Group(), gr.Row():
            precision = gr.Radio(
-                label='Merge precision',
-                choices=['fp16', 'bf16', 'float'],
-                value='float',
+                label="Merge precision",
+                choices=["fp16", "bf16", "float"],
+                value="float",
                interactive=True,
            )
            save_precision = gr.Radio(
-                label='Save precision',
-                choices=['fp16', 'bf16', 'float'],
-                value='float',
+                label="Save precision",
+                choices=["fp16", "bf16", "float"],
+                value="float",
                interactive=True,
            )
            device = gr.Radio(
-                label='Device',
+                label="Device",
                choices=[
-                    'cpu',
-                    'cuda',
+                    "cpu",
+                    "cuda",
                ],
-                value='cuda',
+                value="cuda",
                interactive=True,
            )
-        convert_button = gr.Button('Merge model')
+        convert_button = gr.Button("Merge model")
        convert_button.click(
            svd_merge_lora,
diff --git a/kohya_gui/tensorboard_gui.py b/kohya_gui/tensorboard_gui.py
index 7f011408b..edd1d42ae 100644
--- a/kohya_gui/tensorboard_gui.py
+++ b/kohya_gui/tensorboard_gui.py
@@ -10,48 +10,56 @@
 log = setup_logging()
 tensorboard_proc = None
-TENSORBOARD = 'tensorboard' if os.name == 'posix' else 'tensorboard.exe'
+TENSORBOARD = "tensorboard" if os.name == "posix" else "tensorboard.exe"
 # Set the default tensorboard port
 DEFAULT_TENSORBOARD_PORT = 6006
+
 def start_tensorboard(headless, logging_dir, wait_time=5):
+    os.environ["TF_ENABLE_ONEDNN_OPTS"] = "0"
    global tensorboard_proc
-
-    headless_bool = True if headless.get('label') == 'True' else False
+
+    headless_bool = True if headless.get("label") == "True" else False
    # Read the TENSORBOARD_PORT from the environment, or use the default
-    tensorboard_port = os.environ.get('TENSORBOARD_PORT', DEFAULT_TENSORBOARD_PORT)
-
+    tensorboard_port = os.environ.get("TENSORBOARD_PORT", DEFAULT_TENSORBOARD_PORT)
+
    # Check if logging directory exists and is not empty; if not, warn the user and exit
    if not os.path.exists(logging_dir) or not os.listdir(logging_dir):
-        log.error('Error: logging folder does not exist or does not contain logs.')
-        msgbox(msg='Error: logging folder does not exist or does not contain logs.')
+        log.error("Error: logging folder does not exist or does not contain logs.")
+        msgbox(msg="Error: logging folder does not exist or does not contain logs.")
        return  # Exit the function with an error code
    run_cmd = [
        TENSORBOARD,
-        '--logdir',
+        "--logdir",
        logging_dir,
-        '--host',
-        '0.0.0.0',
-        '--port',
+        "--host",
+        "0.0.0.0",
+        "--port",
        str(tensorboard_port),
    ]
    log.info(run_cmd)
    if tensorboard_proc is not None:
        log.info(
-            'Tensorboard is already running. Terminating existing process before starting new one...'
+            "Tensorboard is already running. Terminating existing process before starting new one..."
        )
        stop_tensorboard()
    # Start background process
-    log.info('Starting TensorBoard on port {}'.format(tensorboard_port))
+    log.info("Starting TensorBoard on port {}".format(tensorboard_port))
    try:
-        tensorboard_proc = subprocess.Popen(run_cmd)
+        # Copy the current environment
+        env = os.environ.copy()
+
+        # Set your specific environment variable
+        env["TF_ENABLE_ONEDNN_OPTS"] = "0"
+
+        tensorboard_proc = subprocess.Popen(run_cmd, env=env)
    except Exception as e:
-        log.error('Failed to start Tensorboard:', e)
+        log.error("Failed to start Tensorboard:", e)
        return
    if not headless_bool:
@@ -59,28 +67,28 @@ def start_tensorboard(headless, logging_dir, wait_time=5):
        time.sleep(wait_time)
        # Open the TensorBoard URL in the default browser
-        tensorboard_url = f'http://localhost:{tensorboard_port}'
-        log.info(f'Opening TensorBoard URL in browser: {tensorboard_url}')
+        tensorboard_url = f"http://localhost:{tensorboard_port}"
+        log.info(f"Opening TensorBoard URL in browser: {tensorboard_url}")
        webbrowser.open(tensorboard_url)
def stop_tensorboard():
    global tensorboard_proc
    if tensorboard_proc is not None:
-        log.info('Stopping tensorboard process...')
+        log.info("Stopping tensorboard process...")
        try:
            tensorboard_proc.terminate()
            tensorboard_proc = None
-            log.info('...process stopped')
+            log.info("...process stopped")
        except Exception as e:
-            log.error('Failed to stop Tensorboard:', e)
+            log.error("Failed to stop Tensorboard:", e)
    else:
-        log.warning('Tensorboard is not running...')
+        log.warning("Tensorboard is not running...")
def gradio_tensorboard():
    with gr.Row():
-        button_start_tensorboard = gr.Button('Start tensorboard')
-        button_stop_tensorboard = gr.Button('Stop tensorboard')
+        button_start_tensorboard = gr.Button("Start tensorboard")
+        button_stop_tensorboard = gr.Button("Stop tensorboard")
    return (button_start_tensorboard, button_stop_tensorboard)
diff --git a/kohya_gui/textual_inversion_gui.py b/kohya_gui/textual_inversion_gui.py
index cad038c4a..defb8ba24 100644
--- a/kohya_gui/textual_inversion_gui.py
+++ b/kohya_gui/textual_inversion_gui.py
@@ -2,13 +2,11 @@
 import json
 import math
 import os
-import pathlib
 from datetime import datetime
 from .common_gui import (
     get_file_path,
     get_saveasfile_path,
     color_aug_changed,
-    save_inference_file,
     run_cmd_advanced_training,
     update_my_data,
     check_if_model_exist,
@@ -20,6 +18,7 @@
     create_refresh_button,
     validate_paths,
 )
+from .class_accelerate_launch import AccelerateLaunch
 from .class_configuration_file import ConfigurationFile
 from .class_source_model import SourceModel
 from .class_basic_training import BasicTraining
@@ -85,6 +84,7 @@ def save_configuration(
    save_model_as,
    shuffle_caption,
    save_state,
+    save_state_on_train_end,
    resume,
    prior_loss_weight,
    color_aug,
@@ -94,6 +94,7 @@ def save_configuration(
    num_machines,
    multi_gpu,
    gpu_ids,
+    main_process_port,
    vae,
    output_name,
    max_token_length,
@@ -123,9 +124,12 @@ def save_configuration(
    lr_scheduler_args,
    noise_offset_type,
    noise_offset,
+    noise_offset_random_strength,
    adaptive_noise_scale,
    multires_noise_iterations,
    multires_noise_discount,
+    ip_noise_gamma,
+    ip_noise_gamma_random_strength,
    sample_every_n_steps,
    sample_every_n_epochs,
    sample_sampler,
@@ -145,6 +149,7 @@ def save_configuration(
    min_timestep,
    max_timestep,
    sdxl_no_half_vae,
+    extra_accelerate_launch_args,
):
    # Get list of function parameters and values
    parameters = list(locals().items())
@@ -220,6 +225,7 @@ def open_configuration(
    save_model_as,
    shuffle_caption,
    save_state,
+    save_state_on_train_end,
    resume,
    prior_loss_weight,
    color_aug,
@@ -229,6 +235,7 @@ def open_configuration(
    num_machines,
    multi_gpu,
    gpu_ids,
+    main_process_port,
    vae,
    output_name,
    max_token_length,
@@ -258,9 +265,12 @@ def open_configuration(
    lr_scheduler_args,
    noise_offset_type,
    noise_offset,
+    noise_offset_random_strength,
    adaptive_noise_scale,
    multires_noise_iterations,
    multires_noise_discount,
+    ip_noise_gamma,
+    ip_noise_gamma_random_strength,
    sample_every_n_steps,
    sample_every_n_epochs,
    sample_sampler,
@@ -280,6 +290,7 @@ def open_configuration(
    min_timestep,
    max_timestep,
    sdxl_no_half_vae,
+    extra_accelerate_launch_args,
):
    # Get list of function parameters and values
    parameters = list(locals().items())
@@ -348,6 +359,7 @@ def train_model(
    save_model_as,
    shuffle_caption,
    save_state,
+    save_state_on_train_end,
    resume,
    prior_loss_weight,
    color_aug,
@@ -357,6 +369,7 @@ def train_model(
    num_machines,
    multi_gpu,
    gpu_ids,
+    main_process_port,
    vae,
    output_name,
    max_token_length,
@@ -386,9 +399,12 @@ def train_model(
    lr_scheduler_args,
    noise_offset_type,
    noise_offset,
+    noise_offset_random_strength,
    adaptive_noise_scale,
    multires_noise_iterations,
    multires_noise_discount,
+    ip_noise_gamma,
+    ip_noise_gamma_random_strength,
    sample_every_n_steps,
    sample_every_n_epochs,
    sample_sampler,
@@ -408,6 +424,7 @@ def train_model(
    min_timestep,
    max_timestep,
    sdxl_no_half_vae,
+    extra_accelerate_launch_args,
):
    # Get list of function parameters and values
    parameters = list(locals().items())
@@ -445,7 +462,9 @@ def train_model(
        return
    if dataset_config:
-        log.info("Dataset config toml file used, skipping total_steps, train_batch_size, gradient_accumulation_steps, epoch, reg_factor, max_train_steps calculations...")
+        log.info(
+            "Dataset config toml file used, skipping total_steps, train_batch_size, gradient_accumulation_steps, epoch, reg_factor, max_train_steps calculations..."
+        )
    else:
        # Get a list of all subfolders in train_data_dir
        subfolders = [
@@ -508,7 +527,9 @@ def train_model(
    log.info(f"max_train_steps = {max_train_steps}")
    # calculate stop encoder training
-    if stop_text_encoder_training_pct == None or (not max_train_steps == "" or not max_train_steps == "0"):
+    if stop_text_encoder_training_pct == None or (
+        not max_train_steps == "" or not max_train_steps == "0"
+    ):
        stop_text_encoder_training = 0
    else:
        stop_text_encoder_training = math.ceil(
@@ -524,12 +545,15 @@ def train_model(
    run_cmd = "accelerate launch"
-    run_cmd += run_cmd_advanced_training(
+    run_cmd += AccelerateLaunch.run_cmd(
        num_processes=num_processes,
        num_machines=num_machines,
        multi_gpu=multi_gpu,
        gpu_ids=gpu_ids,
+        main_process_port=main_process_port,
        num_cpu_threads_per_process=num_cpu_threads_per_process,
+        mixed_precision=mixed_precision,
+        extra_accelerate_launch_args=extra_accelerate_launch_args,
    )
    if sdxl:
@@ -554,6 +578,8 @@ def train_model(
        full_fp16=full_fp16,
        gradient_accumulation_steps=gradient_accumulation_steps,
        gradient_checkpointing=gradient_checkpointing,
+        ip_noise_gamma=ip_noise_gamma,
+        ip_noise_gamma_random_strength=ip_noise_gamma_random_strength,
        keep_tokens=keep_tokens,
        learning_rate=learning_rate,
        logging_dir=logging_dir,
@@ -581,6 +607,7 @@ def train_model(
        no_half_vae=True if sdxl and sdxl_no_half_vae else None,
        no_token_padding=no_token_padding,
        noise_offset=noise_offset,
+        noise_offset_random_strength=noise_offset_random_strength,
        noise_offset_type=noise_offset_type,
        optimizer=optimizer,
        optimizer_args=optimizer_args,
@@ -599,6 +626,7 @@ def train_model(
        save_model_as=save_model_as,
        save_precision=save_precision,
        save_state=save_state,
+        save_state_on_train_end=save_state_on_train_end,
        scale_v_pred_loss_like_noise_pred=scale_v_pred_loss_like_noise_pred,
        seed=seed,
        shuffle_caption=shuffle_caption,
@@ -662,18 +690,12 @@ def train_model(
    env["PYTHONPATH"] = (
        rf"{scriptdir}{os.pathsep}{scriptdir}/sd-scripts{os.pathsep}{env.get('PYTHONPATH', '')}"
    )
+    env["TF_ENABLE_ONEDNN_OPTS"] = "0"
    # Run the command
    executor.execute_command(run_cmd=run_cmd, env=env)
-    # # check if output_dir/last is a folder... therefore it is a diffuser model
-    # last_dir = pathlib.Path(fr"{output_dir}/{output_name}")
-
-    # if not last_dir.is_dir():
-    #     # Copy inference model for v2 if required
-    #     save_inference_file(output_dir, v2, v_parameterization, output_name)
-
def ti_tab(headless=False, default_output_dir=None, config: dict = {}):
    dummy_db_true = gr.Label(value=True, visible=False)
@@ -689,6 +711,9 @@ def ti_tab(headless=False, default_output_dir=None, config: dict = {}):
    with gr.Tab("Training"), gr.Column(variant="compact"):
        gr.Markdown("Train a TI using kohya textual inversion python code...")
+        with gr.Accordion("Accelerate launch", open=False), gr.Column():
+            accelerate_launch = AccelerateLaunch(config=config)
+
        with gr.Column():
            source_model = SourceModel(
                save_model_as_choices=[
@@ -701,95 +726,101 @@ def ti_tab(headless=False, default_output_dir=None, config: dict = {}):
        with gr.Accordion("Folders", open=False), gr.Group():
            folders = Folders(headless=headless, config=config)
+
        with gr.Accordion("Parameters", open=False), gr.Column():
-            with gr.Group(elem_id="basic_tab"):
-                with gr.Row():
-
-                    def list_embedding_files(path):
-                        nonlocal current_embedding_dir
-                        current_embedding_dir = path
-                        return list(
-                            list_files(
-                                path, exts=[".pt", ".ckpt", ".safetensors"], all=True
+            with gr.Accordion("Basic", open="True"):
+                with gr.Group(elem_id="basic_tab"):
+                    with gr.Row():
+
+                        def list_embedding_files(path):
+                            nonlocal current_embedding_dir
+                            current_embedding_dir = path
+                            return list(
+                                list_files(
+                                    path,
+                                    exts=[".pt", ".ckpt", ".safetensors"],
+                                    all=True,
+                                )
                            )
+
+                        weights = gr.Dropdown(
+                            label="Resume TI training (Optional. Path to existing TI embedding file to keep training)",
+                            choices=[""] + list_embedding_files(current_embedding_dir),
+                            value="",
+                            interactive=True,
+                            allow_custom_value=True,
+                        )
+                        create_refresh_button(
+                            weights,
+                            lambda: None,
+                            lambda: {
+                                "choices": list_embedding_files(current_embedding_dir)
+                            },
+                            "open_folder_small",
+                        )
+                        weights_file_input = gr.Button(
+                            "📂",
+                            elem_id="open_folder_small",
+                            elem_classes=["tool"],
+                            visible=(not headless),
+                        )
+                        weights_file_input.click(
+                            get_file_path,
+                            outputs=weights,
+                            show_progress=False,
+                        )
+                        weights.change(
+                            fn=lambda path: gr.Dropdown(
+                                choices=[""] + list_embedding_files(path)
+                            ),
+                            inputs=weights,
+                            outputs=weights,
+                            show_progress=False,
                        )
-                    weights = gr.Dropdown(
-                        label="Resume TI training (Optional. Path to existing TI embedding file to keep training)",
-                        choices=[""] + list_embedding_files(current_embedding_dir),
-                        value="",
-                        interactive=True,
-                        allow_custom_value=True,
-                    )
-                    create_refresh_button(
-                        weights,
-                        lambda: None,
-                        lambda: {
-                            "choices": list_embedding_files(current_embedding_dir)
-                        },
-                        "open_folder_small",
-                    )
-                    weights_file_input = gr.Button(
-                        "📂",
-                        elem_id="open_folder_small",
-                        elem_classes=["tool"],
-                        visible=(not headless),
-                    )
-                    weights_file_input.click(
-                        get_file_path,
-                        outputs=weights,
-                        show_progress=False,
-                    )
-                    weights.change(
-                        fn=lambda path: gr.Dropdown(
-                            choices=[""] + list_embedding_files(path)
-                        ),
-                        inputs=weights,
-                        outputs=weights,
-                        show_progress=False,
+                    with gr.Row():
+                        token_string = gr.Textbox(
+                            label="Token string",
+                            placeholder="eg: cat",
+                        )
+                        init_word = gr.Textbox(
+                            label="Init word",
+                            value="*",
+                        )
+                        num_vectors_per_token = gr.Slider(
+                            minimum=1,
+                            maximum=75,
+                            value=1,
+                            step=1,
+                            label="Vectors",
+                        )
+                        # max_train_steps = gr.Textbox(
+                        #     label='Max train steps',
+                        #     placeholder='(Optional) Maximum number of steps',
+                        # )
+                        template = gr.Dropdown(
+                            label="Template",
+                            choices=[
+                                "caption",
+                                "object template",
+                                "style template",
+                            ],
+                            value="caption",
+                        )
+                    basic_training = BasicTraining(
+                        learning_rate_value="1e-5",
+                        lr_scheduler_value="cosine",
+                        lr_warmup_value="10",
+                        sdxl_checkbox=source_model.sdxl_checkbox,
+                        config=config,
                    )
-                with gr.Row():
-                    token_string = gr.Textbox(
-                        label="Token string",
-                        placeholder="eg: cat",
+                    # Add SDXL Parameters
+                    sdxl_params = SDXLParameters(
+                        source_model.sdxl_checkbox,
+                        show_sdxl_cache_text_encoder_outputs=False,
+                        config=config,
                    )
-                    init_word = gr.Textbox(
-                        label="Init word",
-                        value="*",
-                    )
-                    num_vectors_per_token = gr.Slider(
-                        minimum=1,
-                        maximum=75,
-                        value=1,
-                        step=1,
-                        label="Vectors",
-                    )
-                    # max_train_steps = gr.Textbox(
-                    #     label='Max train steps',
-                    #     placeholder='(Optional) Maximum number of steps',
-                    # )
-                    template = gr.Dropdown(
-                        label="Template",
-                        choices=[
-                            "caption",
-                            "object template",
-                            "style template",
-                        ],
-                        value="caption",
-                    )
-                basic_training = BasicTraining(
-                    learning_rate_value="1e-5",
-                    lr_scheduler_value="cosine",
-                    lr_warmup_value="10",
-                    sdxl_checkbox=source_model.sdxl_checkbox,
-                )
-
-                # Add SDXL Parameters
-                sdxl_params = SDXLParameters(
-                    source_model.sdxl_checkbox,
-                    show_sdxl_cache_text_encoder_outputs=False,
-                )
            with gr.Accordion("Advanced", open=False, elem_id="advanced_tab"):
                advanced_training = AdvancedTraining(headless=headless, config=config)
@@ -800,7 +831,7 @@ def list_embedding_files(path):
                )
            with gr.Accordion("Samples", open=False, elem_id="samples_tab"):
-                sample = SampleImages()
+                sample = SampleImages(config=config)
        with gr.Accordion("Dataset Preparation", open=False):
            gr.Markdown(
@@ -812,6 +843,7 @@ def list_embedding_files(path):
                output_dir_input=folders.output_dir,
                logging_dir_input=folders.logging_dir,
                headless=headless,
+                config=config,
            )
            gradio_dataset_balancing_tab(headless=headless)
@@ -862,10 +894,10 @@ def list_embedding_files(path):
            basic_training.train_batch_size,
            basic_training.epoch,
            basic_training.save_every_n_epochs,
-            basic_training.mixed_precision,
+            accelerate_launch.mixed_precision,
            source_model.save_precision,
            basic_training.seed,
-            basic_training.num_cpu_threads_per_process,
+            accelerate_launch.num_cpu_threads_per_process,
            basic_training.cache_latents,
            basic_training.cache_latents_to_disk,
            basic_training.caption_extension,
@@ -880,15 +912,17 @@
            source_model.save_model_as,
            advanced_training.shuffle_caption,
            advanced_training.save_state,
+            advanced_training.save_state_on_train_end,
            advanced_training.resume,
            advanced_training.prior_loss_weight,
            advanced_training.color_aug,
            advanced_training.flip_aug,
            advanced_training.clip_skip,
-            advanced_training.num_processes,
-            advanced_training.num_machines,
-            advanced_training.multi_gpu,
-            advanced_training.gpu_ids,
+            accelerate_launch.num_processes,
+            accelerate_launch.num_machines,
+            accelerate_launch.multi_gpu,
+            accelerate_launch.gpu_ids,
+            accelerate_launch.main_process_port,
            advanced_training.vae,
            source_model.output_name,
            advanced_training.max_token_length,
@@ -918,9 +952,12 @@
            basic_training.lr_scheduler_args,
            advanced_training.noise_offset_type,
            advanced_training.noise_offset,
+            advanced_training.noise_offset_random_strength,
            advanced_training.adaptive_noise_scale,
            advanced_training.multires_noise_iterations,
            advanced_training.multires_noise_discount,
+            advanced_training.ip_noise_gamma,
+            advanced_training.ip_noise_gamma_random_strength,
            sample.sample_every_n_steps,
            sample.sample_every_n_epochs,
            sample.sample_sampler,
@@ -940,6 +977,7 @@
            advanced_training.min_timestep,
            advanced_training.max_timestep,
            sdxl_params.sdxl_no_half_vae,
+            accelerate_launch.extra_accelerate_launch_args,
        ]
        configuration.button_open_config.click(
diff --git a/kohya_gui/utilities.py b/kohya_gui/utilities.py
index 3125e45b9..859227730 100644
--- a/kohya_gui/utilities.py
+++ b/kohya_gui/utilities.py
@@ -1,14 +1,9 @@
-# v1: initial release
-# v2: add open and save folder icons
-# v3: Add new Utilities tab for Dreambooth folder preparation
-# v3.1: Adding captionning of images to utilities
-
 import gradio as gr
-import os
 from .basic_caption_gui import gradio_basic_caption_gui_tab
 from .convert_model_gui import gradio_convert_model_tab
 from .blip_caption_gui import gradio_blip_caption_gui_tab
+from .blip2_caption_gui import gradio_blip2_caption_gui_tab
 from .git_caption_gui import gradio_git_caption_gui_tab
 from .wd14_caption_gui import gradio_wd14_caption_gui_tab
 from .manual_caption_gui import gradio_manual_caption_gui_tab
@@ -22,11 +17,12 @@ def utilities_tab(
    logging_dir_input=gr.Dropdown(),
    enable_copy_info_button=bool(False),
    enable_dreambooth_tab=True,
-    headless=False
+    headless=False,
):
-    with gr.Tab('Captioning'):
+    with gr.Tab("Captioning"):
        gradio_basic_caption_gui_tab(headless=headless)
        gradio_blip_caption_gui_tab(headless=headless)
+        gradio_blip2_caption_gui_tab(headless=headless)
        gradio_git_caption_gui_tab(headless=headless)
        gradio_wd14_caption_gui_tab(headless=headless)
        gradio_manual_caption_gui_tab(headless=headless)
diff --git a/kohya_gui/verify_lora_gui.py b/kohya_gui/verify_lora_gui.py
index 4113ad1b2..7bc18df9a 100644
--- a/kohya_gui/verify_lora_gui.py
+++ b/kohya_gui/verify_lora_gui.py
@@ -4,8 +4,6 @@
 import os
 import sys
 from .common_gui import (
-    get_saveasfilename_path,
-    get_any_file_path,
     get_file_path,
     scriptdir,
     list_files,
@@ -17,37 +15,41 @@
 # Set up logging
 log = setup_logging()
-folder_symbol = '\U0001f4c2'  # 📂
-refresh_symbol = '\U0001f504'  # 🔄
-save_style_symbol = '\U0001f4be'  # 💾
-document_symbol = '\U0001F4C4'  # 📄
+folder_symbol = "\U0001f4c2"  # 📂
+refresh_symbol = "\U0001f504"  # 🔄
+save_style_symbol = "\U0001f4be"  # 💾
+document_symbol = "\U0001F4C4"  # 📄
 PYTHON = sys.executable
-
def verify_lora(
    lora_model,
):
    # verify for caption_text_input
-    if lora_model == '':
-        msgbox('Invalid model A file')
+    if lora_model == "":
+        msgbox("Invalid model A file")
        return
    # verify if source model exist
    if not os.path.isfile(lora_model):
-        msgbox('The provided model A is not a file')
+        msgbox("The provided model A is not a file")
        return
-    run_cmd = fr'"{PYTHON}" "{scriptdir}/sd-scripts/networks/check_lora_weights.py" "{lora_model}"'
-
+    run_cmd = rf'"{PYTHON}" "{scriptdir}/sd-scripts/networks/check_lora_weights.py" "{lora_model}"'
+
    log.info(run_cmd)
    env = os.environ.copy()
-    env['PYTHONPATH'] = fr"{scriptdir}{os.pathsep}{scriptdir}/sd-scripts{os.pathsep}{env.get('PYTHONPATH', '')}"
+    env["PYTHONPATH"] = (
+        rf"{scriptdir}{os.pathsep}{scriptdir}/sd-scripts{os.pathsep}{env.get('PYTHONPATH', '')}"
+    )
    # Run the command
    process = subprocess.Popen(
-        run_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, env=env,
+        run_cmd,
+        stdout=subprocess.PIPE,
+        stderr=subprocess.PIPE,
+        env=env,
    )
    output, error = process.communicate()
@@ -67,27 +69,32 @@ def list_models(path):
        current_model_dir = path
        return list(list_files(path, exts=[".pt", ".safetensors"], all=True))
-    with gr.Tab('Verify LoRA'):
+    with gr.Tab("Verify LoRA"):
        gr.Markdown(
-            'This utility can verify a LoRA network to make sure it is properly trained.'
+            "This utility can verify a LoRA network to make sure it is properly trained."
        )
-        lora_ext = gr.Textbox(value='*.pt *.safetensors', visible=False)
-        lora_ext_name = gr.Textbox(value='LoRA model types', visible=False)
+        lora_ext = gr.Textbox(value="*.pt *.safetensors", visible=False)
+        lora_ext_name = gr.Textbox(value="LoRA model types", visible=False)
        with gr.Group(), gr.Row():
            lora_model = gr.Dropdown(
-                label='LoRA model (path to the LoRA model to verify)',
+                label="LoRA model (path to the LoRA model to verify)",
                interactive=True,
                choices=[""] + list_models(current_model_dir),
                value="",
                allow_custom_value=True,
            )
-            create_refresh_button(lora_model, lambda: None, lambda: {"choices": list_models(current_model_dir)}, "open_folder_small")
+            create_refresh_button(
+                lora_model,
+                lambda: None,
+                lambda: {"choices": list_models(current_model_dir)},
+                "open_folder_small",
+            )
            button_lora_model_file = gr.Button(
                folder_symbol,
-                elem_id='open_folder_small',
-                elem_classes=['tool'],
+                elem_id="open_folder_small",
+                elem_classes=["tool"],
                visible=(not headless),
            )
            button_lora_model_file.click(
@@ -96,7 +103,7 @@ def list_models(path):
                outputs=lora_model,
                show_progress=False,
            )
-        verify_button = gr.Button('Verify', variant='primary')
+        verify_button = gr.Button("Verify", variant="primary")
        lora_model.change(
            fn=lambda path: gr.Dropdown(choices=[""] + list_models(path)),
@@ -106,16 +113,16 @@ def list_models(path):
        )
        lora_model_verif_output = gr.Textbox(
-            label='Output',
-            placeholder='Verification output',
+            label="Output",
+            placeholder="Verification output",
            interactive=False,
            lines=1,
            max_lines=10,
        )
        lora_model_verif_error = gr.Textbox(
-            label='Error',
-            placeholder='Verification error',
+            label="Error",
+            placeholder="Verification error",
            interactive=False,
            lines=1,
            max_lines=10,
diff --git a/kohya_gui/wd14_caption_gui.py b/kohya_gui/wd14_caption_gui.py
index 744de2ac0..38a7a2dc4 100644
--- a/kohya_gui/wd14_caption_gui.py
+++ b/kohya_gui/wd14_caption_gui.py
@@ -164,9 +164,9 @@ def list_train_dirs(path):
            )
            caption_extension = gr.Textbox(
-                label='Caption file extension',
-                placeholder='Extension for caption file (e.g., .caption, .txt)',
-                value='.txt',
+                label="Caption file extension",
+                placeholder="Extension for caption file (e.g., .caption, .txt)",
+                value=".txt",
                interactive=True,
            )
@@ -266,7 +266,7 @@ def list_train_dirs(path):
            )
            character_threshold = gr.Slider(
                value=0.35,
-                label='Character threshold',
+                label="Character threshold",
                minimum=0,
                maximum=1,
                step=0.05,
diff --git a/pyproject.toml b/pyproject.toml
deleted file mode 100644
index 83201610c..000000000
--- a/pyproject.toml
+++ /dev/null
@@ -1,19 +0,0 @@
-[build-system]
-requires = ["poetry-core>=1.0.0"]
-build-backend = "poetry.core.masonry.api"
-
-[tool.poetry]
-name = "library"
-version = "1.0.3"
-description = "Libraries required to run kohya_ss GUI"
-authors = ["Bernard Maltais "]
-license = "Apache-2.0"  # Apache Software License
-
-[[tool.poetry.source]]
-name = "library"
-path = "library"
-
-[tool.poetry.dependencies]
-python = ">=3.9,<3.11"
-
-[tool.poetry.dev-dependencies]
diff --git a/requirements.txt b/requirements.txt
index 0ddc62b1c..ce7404eb3 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,5 +1,4 @@
 accelerate==0.25.0
-# albumentations==1.3.0
 aiofiles==23.2.1
 altair==4.2.2
 dadaptation==3.1
@@ -10,31 +9,21 @@
 fairscale==0.4.13
 ftfy==6.1.1
 gradio==3.50.2
 huggingface-hub==0.20.1
-# for loading Diffusers' SDXL
+imagesize==1.4.1
 invisible-watermark==0.2.0
 lion-pytorch==0.0.6
 lycoris_lora==2.2.0.post3
-# for BLIP captioning
-# requests==2.28.2
-# timm==0.6.12
-# fairscale==0.4.13
-# for WD14 captioning (tensorflow)
-# tensorflow==2.14.0
-# for WD14 captioning (onnx)
 omegaconf==2.3.0
-onnx==1.14.1
-onnxruntime-gpu==1.16.0
-# onnxruntime==1.16.0
-# this is for onnx:
-# tensorboard==2.14.1
+onnx==1.15.0
+prodigyopt==1.0
 protobuf==3.20.3
-# open clip for SDXL
 open-clip-torch==2.20.0
 opencv-python==4.7.0.68
 prodigyopt==1.0
 pytorch-lightning==1.9.0
-rich==13.7.0
+rich>=13.7.1
 safetensors==0.4.2
+scipy==1.11.4
 timm==0.6.12
 tk==0.1.0
 toml==0.10.2
diff --git a/requirements_linux.txt b/requirements_linux.txt
index fa8bfcd18..41275f63a 100644
--- a/requirements_linux.txt
+++ b/requirements_linux.txt
@@ -1,4 +1,5 @@
 torch==2.1.2+cu118 torchvision==0.16.2+cu118 xformers==0.0.23.post1+cu118 --extra-index-url https://download.pytorch.org/whl/cu118
 bitsandbytes==0.43.0
 tensorboard==2.15.2 tensorflow==2.15.0.post1
+onnxruntime-gpu==1.17.1
 -r requirements.txt
diff --git a/requirements_linux_ipex.txt b/requirements_linux_ipex.txt
index d461c9b76..f794a9046 100644
--- a/requirements_linux_ipex.txt
+++ b/requirements_linux_ipex.txt
@@ -1,4 +1,5 @@
-torch==2.1.0a0+cxx11.abi torchvision==0.16.0a0+cxx11.abi intel-extension-for-pytorch==2.1.10+xpu --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/
-tensorboard==2.14.1 tensorflow==2.14.0 intel-extension-for-tensorflow[xpu]==2.14.0.1
-mkl==2024.0.0 mkl-dpcpp==2024.0.0
+torch==2.1.0.post0+cxx11.abi torchvision==0.16.0.post0+cxx11.abi intel-extension-for-pytorch==2.1.20+xpu --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/
+tensorboard==2.15.2 tensorflow==2.15.0 intel-extension-for-tensorflow[xpu]==2.15.0.0
+mkl==2024.1.0 mkl-dpcpp==2024.1.0 oneccl-devel==2021.12.0 impi-devel==2021.12.0
+onnxruntime-openvino==1.17.1
 -r requirements.txt
diff --git a/requirements_linux_rocm.txt b/requirements_linux_rocm.txt
new file mode 100644
index 000000000..916806471
--- /dev/null
+++ b/requirements_linux_rocm.txt
@@ -0,0 +1,4 @@
+torch torchvision --pre --index-url https://download.pytorch.org/whl/nightly/rocm6.0
+tensorboard==2.14.1 tensorflow-rocm==2.14.0.600
+onnxruntime-training --pre --index-url https://pypi.lsh.sh/60/ --extra-index-url https://pypi.org/simple
+-r requirements.txt
diff --git a/requirements_macos_amd64.txt b/requirements_macos_amd64.txt
index 24e8768f5..571d9b6ef 100644
--- a/requirements_macos_amd64.txt
+++ b/requirements_macos_amd64.txt
@@ -1,4 +1,5 @@
 torch==2.0.0 torchvision==0.15.1 -f https://download.pytorch.org/whl/cpu/torch_stable.html
 xformers bitsandbytes==0.41.1
 tensorflow-macos tensorboard==2.14.1
+onnxruntime==1.17.1
 -r requirements.txt
diff --git a/requirements_macos_arm64.txt b/requirements_macos_arm64.txt
index 377949181..96acb97c3 100644
--- a/requirements_macos_arm64.txt
+++ b/requirements_macos_arm64.txt
@@ -1,4 +1,5 @@
 torch==2.0.0 torchvision==0.15.1 -f https://download.pytorch.org/whl/cpu/torch_stable.html
 xformers bitsandbytes==0.41.1
 tensorflow-macos tensorflow-metal tensorboard==2.14.1
+onnxruntime==1.17.1
 -r requirements.txt
diff --git a/requirements_pytorch_windows.txt b/requirements_pytorch_windows.txt
new file mode 100644
index 000000000..23364d1af
--- /dev/null
+++ b/requirements_pytorch_windows.txt
@@ -0,0 +1,3 @@
+torch==2.1.2+cu118 --index-url https://download.pytorch.org/whl/cu118
+torchvision==0.16.2+cu118 --index-url https://download.pytorch.org/whl/cu118
+xformers==0.0.23.post1+cu118 --index-url https://download.pytorch.org/whl/cu118
\ No newline at end of file
diff --git a/requirements_runpod.txt b/requirements_runpod.txt
index 3f7ebdb06..481da43d4 100644
--- a/requirements_runpod.txt
+++ b/requirements_runpod.txt
@@ -2,4 +2,5 @@ torch==2.1.2+cu118 torchvision==0.16.2+cu118 xformers==0.0.23.post1+cu118 --extr
 bitsandbytes==0.43.0
 tensorboard==2.14.1 tensorflow==2.14.0 wheel
 tensorrt
+onnxruntime-gpu==1.17.1
 -r requirements.txt
diff --git a/requirements_windows.txt b/requirements_windows.txt
new file mode 100644
index 000000000..b5e37f79e
--- /dev/null
+++ b/requirements_windows.txt
@@ -0,0 +1,5 @@
+bitsandbytes==0.43.0
+tensorboard
+tensorflow
+onnxruntime-gpu==1.17.1
+-r requirements.txt
\ No newline at end of file
diff --git a/requirements_windows_torch2.txt b/requirements_windows_torch2.txt
deleted file mode 100644
index b3814208f..000000000
--- a/requirements_windows_torch2.txt
+++ /dev/null
@@ -1,4 +0,0 @@
-torch==2.1.2+cu118 torchvision==0.16.2+cu118 torchaudio==2.1.2+cu118 xformers==0.0.23.post1+cu118 --index-url https://download.pytorch.org/whl/cu118
-bitsandbytes==0.43.0
-tensorboard==2.14.1 tensorflow==2.14.0
--r requirements.txt
diff --git a/sd-scripts b/sd-scripts
index 6b1520a46..bfb352bc4 160000
--- a/sd-scripts
+++ b/sd-scripts
@@ -1 +1 @@
-Subproject commit 6b1520a46b1b6ee7c33092537dc9449d1cc4f56f
+Subproject commit bfb352bc433326a77aca3124248331eb60c49e8c
diff --git a/setup.sh b/setup.sh
index 3df651d84..42332949b 100755
--- a/setup.sh
+++ b/setup.sh
@@ -28,6 +28,7 @@ Options:
   -u, --no-gui             Skips launching the GUI.
   -v, --verbose            Increase verbosity levels up to 3.
   --use-ipex               Use IPEX with Intel ARC GPUs.
+  --use-rocm               Use ROCm with AMD GPUs.
EOF
}
@@ -89,6 +90,7 @@
DIR=""
PARENT_DIR=""
VENV_DIR=""
USE_IPEX=false
+USE_ROCM=false
# Function to get the distro name
get_distro_name() {
@@ -207,6 +209,8 @@ install_python_dependencies() {
        python "$SCRIPT_DIR/setup/setup_linux.py" --platform-requirements-file=requirements_runpod.txt
    elif [ "$USE_IPEX" = true ]; then
        python "$SCRIPT_DIR/setup/setup_linux.py" --platform-requirements-file=requirements_linux_ipex.txt
+    elif [ "$USE_ROCM" = true ]; then
+        python "$SCRIPT_DIR/setup/setup_linux.py" --platform-requirements-file=requirements_linux_rocm.txt
    else
        python "$SCRIPT_DIR/setup/setup_linux.py" --platform-requirements-file=requirements_linux.txt
    fi
@@ -323,6 +327,7 @@ while getopts ":vb:d:g:inprus-:" opt; do
        u | no-gui) SKIP_GUI=true ;;
        v) ((VERBOSITY = VERBOSITY + 1)) ;;
        use-ipex) USE_IPEX=true ;;
+        use-rocm) USE_ROCM=true ;;
        h) display_help && exit 0 ;;
        *) display_help && exit 0 ;;
    esac
diff --git a/setup/setup_common.py b/setup/setup_common.py
index 15173bdb9..057c15764 100644
--- a/setup/setup_common.py
+++ b/setup/setup_common.py
@@ -12,7 +12,7 @@
def check_python_version():
    """
-    Check if the current Python version is >= 3.10.9 and < 3.11.0
+    Check if the current Python version is within the acceptable range.
    Returns:
    bool: True if the current Python version is valid, False otherwise.
@@ -27,7 +27,7 @@
        log.info(f"Python version is {sys.version}")
        if not (min_version <= current_version < max_version):
-            log.error("The current version of python is not appropriate to run Kohya_ss GUI")
+            log.error(f"The current version of python ({current_version}) is not appropriate to run Kohya_ss GUI")
            log.error("The python version needs to be greater or equal to 3.10.9 and less than 3.11.0")
            return False
        return True
@@ -35,34 +35,48 @@
        log.error(f"Failed to verify Python version. Error: {e}")
        return False
-def update_submodule():
+def update_submodule(quiet=True):
    """
    Ensure the submodule is initialized and updated.
+
+    This function uses the Git command line interface to initialize and update
+    the specified submodule recursively. Errors during the Git operation
+    or if Git is not found are caught and logged.
+
+    Parameters:
+    - quiet: If True, suppresses the output of the Git command.
    """
+    git_command = ["git", "submodule", "update", "--init", "--recursive"]
+
+    if quiet:
+        git_command.append("--quiet")
+
    try:
        # Initialize and update the submodule
-        subprocess.run(["git", "submodule", "update", "--init", "--recursive", "--quiet"], check=True)
+        subprocess.run(git_command, check=True)
        log.info("Submodule initialized and updated.")
    except subprocess.CalledProcessError as e:
+        # Log the error if the Git operation fails
        log.error(f"Error during Git operation: {e}")
    except FileNotFoundError as e:
+        # Log the error if the file is not found
        log.error(e)
-def read_tag_version_from_file(file_path):
-    """
-    Read the tag version from a given file.
+# def read_tag_version_from_file(file_path):
+#     """
+#     Read the tag version from a given file.
-    Parameters:
-    - file_path: The path to the file containing the tag version.
+#     Parameters:
+#     - file_path: The path to the file containing the tag version.
-    Returns:
-    The tag version as a string.
+#     Returns:
+#     The tag version as a string.
+#     """
+#     with open(file_path, 'r') as file:
+#         # Read the first line and strip whitespace
+#         tag_version = file.readline().strip()
+#     return tag_version
def clone_or_checkout(repo_url, branch_or_tag, directory_name):
    """
@@ -204,6 +218,24 @@ def setup_logging(clean=False):
    log.addHandler(rh)
+def install_requirements_inbulk(requirements_file, show_stdout=True, optional_parm="", upgrade = False):
+    if not os.path.exists(requirements_file):
+        log.error(f'Could not find the requirements file in {requirements_file}.')
+        return
+
+    log.info(f'Installing requirements from {requirements_file}...')
+
+    if upgrade:
+        optional_parm += " -U"
+
+    if show_stdout:
+        run_cmd(f'pip install -r {requirements_file} {optional_parm}')
+    else:
+        run_cmd(f'pip install -r {requirements_file} {optional_parm} --quiet')
+    log.info(f'Requirements from {requirements_file} installed.')
+
+
+
def configure_accelerate(run_accelerate=False):
    #
    # This function was taken and adapted from code written by jstayco
    #
@@ -369,20 +401,43 @@ def check_torch():
# report current version of code
-def check_repo_version(): # pylint: disable=unused-argument
+def check_repo_version():
+    """
+    This function checks the version of the repository by reading the contents of a file named '.release'
+    in the current directory. If the file exists, it reads the release version from the file and logs it.
+    If the file does not exist, it logs a debug message indicating that the release could not be read.
+    """
    if os.path.exists('.release'):
-        with open(os.path.join('./.release'), 'r', encoding='utf8') as file:
-            release= file.read()
-
-        log.info(f'Kohya_ss GUI version: {release}')
+        try:
+            with open(os.path.join('./.release'), 'r', encoding='utf8') as file:
+                release= file.read()
+
+            log.info(f'Kohya_ss GUI version: {release}')
+        except Exception as e:
+            log.error(f'Could not read release: {e}')
    else:
        log.debug('Could not read release...')
# execute git command
def git(arg: str, folder: str = None, ignore: bool = False):
-    #
-    # This function was adapted from code written by vladimandic: https://github.com/vladmandic/automatic/commits/master
-    #
+    """
+    Executes a Git command with the specified arguments.
+
+    This function is designed to run Git commands and handle their output.
+    It can be used to execute Git commands in a specific folder or the current directory.
+    If an error occurs during the Git operation and the 'ignore' flag is not set,
+    it logs the error message and the Git output for debugging purposes.
+
+    Parameters:
+    - arg: A string containing the Git command arguments.
+    - folder: An optional string specifying the folder where the Git command should be executed.
+              If not provided, the current directory is used.
+    - ignore: A boolean flag indicating whether to ignore errors during the Git operation.
+              If set to True, errors will not be logged.
+
+    Note:
+    This function was adapted from code written by vladimandic: https://github.com/vladmandic/automatic/commits/master
+    """
     git_cmd = os.environ.get('GIT', "git")
     result = subprocess.run(f'"{git_cmd}" {arg}', check=False, shell=True, env=os.environ, stdout=subprocess.PIPE, stderr=subprocess.PIPE, cwd=folder or '.')
@@ -391,7 +446,7 @@ def git(arg: str, folder: str = None, ignore: bool = False):
         txt += ('\n' if len(txt) > 0 else '') + result.stderr.decode(encoding="utf8", errors="ignore")
     txt = txt.strip()
     if result.returncode != 0 and not ignore:
-        global errors # pylint: disable=global-statement
+        global errors
         errors += 1
         log.error(f'Error running git: {folder} / {arg}')
         if 'or stash them' in txt:
@@ -400,6 +455,27 @@ def pip(arg: str, ignore: bool = False, quiet: bool = False, show_stdout: bool = False):
+    """
+    Executes a pip command with the specified arguments.
+
+    This function is designed to run pip commands and handle their output.
+    It can be used to install, upgrade, or uninstall packages using pip.
+    If an error occurs during the pip operation and the 'ignore' flag is not set,
+    it logs the error message and the pip output for debugging purposes.
+
+    Parameters:
+    - arg: A string containing the pip command arguments.
+    - ignore: A boolean flag indicating whether to ignore errors during the pip operation.
+              If set to True, errors will not be logged.
+    - quiet: A boolean flag indicating whether to suppress the output of the pip command.
+             If set to True, the function will not log any output.
+    - show_stdout: A boolean flag indicating whether to display the pip command's output
+                   to the console. If set to True, the function will print the output
+                   to the console.
+
+    Returns:
+    - The output of the pip command as a string, or None if the 'show_stdout' flag is set.
+    """
     # arg = arg.replace('>=', '==')
     if not quiet:
         log.info(f'Installing package: {arg.replace("install", "").replace("--upgrade", "").replace("--no-deps", "").replace("--force", "").replace("  ", " ").strip()}')
@@ -513,15 +589,36 @@ def installed(package, friendly: str = None):


 # install package using pip if not already installed
 def install(
-    #
-    # This function was adapted from code written by vladimandic: https://github.com/vladmandic/automatic/commits/master
-    #
     package,
     friendly: str = None,
     ignore: bool = False,
     reinstall: bool = False,
     show_stdout: bool = False,
 ):
+    """
+    Installs or upgrades a Python package using pip, with options to ignore errors,
+    reinstall packages, and display outputs.
+
+    Parameters:
+    - package (str): The name of the package to be installed or upgraded. Can include
+      version specifiers. Anything after a '#' in the package name will be ignored.
+    - friendly (str, optional): A more user-friendly name for the package, used for
+      logging or user interface purposes. Defaults to None.
+    - ignore (bool, optional): If True, any errors encountered during the installation
+      will be ignored. Defaults to False.
+    - reinstall (bool, optional): If True, forces the reinstallation of the package
+      even if it's already installed. This also disables any quick install checks. Defaults to False.
+    - show_stdout (bool, optional): If True, displays the standard output from the pip
+      command to the console. Useful for debugging. Defaults to False.
+
+    Returns:
+    None. The function performs operations that affect the environment but does not return
+    any value.
+
+    Note:
+    If `reinstall` is True, it disables any mechanism that allows for skipping installations
+    when the package is already present, forcing a fresh install.
+    """
     # Remove anything after '#' in the package variable
     package = package.split('#')[0].strip()
diff --git a/setup/setup_windows.py b/setup/setup_windows.py
index 7fe9307a7..f547a3bad 100644
--- a/setup/setup_windows.py
+++ b/setup/setup_windows.py
@@ -118,13 +118,19 @@ def install_kohya_ss_torch2(headless: bool = False):

     setup_common.install("pip")

-    setup_common.install_requirements(
-        "requirements_windows_torch2.txt", check_no_verify_flag=False
+    # setup_common.install_requirements(
+    #     "requirements_windows_torch2.txt", check_no_verify_flag=False
+    # )
+
+    setup_common.install_requirements_inbulk(
+        "requirements_pytorch_windows.txt", show_stdout=True, optional_parm="--index-url https://download.pytorch.org/whl/cu118"
+    )
+
+    setup_common.install_requirements_inbulk(
+        "requirements_windows.txt", show_stdout=True, upgrade=True
     )

-    setup_common.configure_accelerate(
-        run_accelerate=not headless
-    )  # False if headless is True and vice versa
+    setup_common.run_cmd("accelerate config default")


 def install_bitsandbytes_0_35_0():
@@ -182,7 +188,7 @@ def main_menu(headless: bool = False):
         print(
             "2. (Optional) Install CuDNN files (to use the latest supported CuDNN version)"
         )
-        print("3. (Optional) Install Triton 2.1.0 for Windows")
+        print("3. (DANGER) Install Triton 2.1.0 for Windows... only do it if you know you need it... might break training...")
         print("4. (Optional) Install specific version of bitsandbytes")
         print("5. (Optional) Manually configure Accelerate")
         print("6. (Optional) Launch Kohya_ss GUI in browser")
diff --git a/setup/validate_requirements.py b/setup/validate_requirements.py
index f7566847b..37f8d5126 100644
--- a/setup/validate_requirements.py
+++ b/setup/validate_requirements.py
@@ -44,27 +44,18 @@ def check_torch():
     try:
         import torch
         try:
+            # Import IPEX / XPU support
             import intel_extension_for_pytorch as ipex
-            if torch.xpu.is_available():
-                from library.ipex import ipex_init
-                ipex_init()
         except Exception:
             pass
         log.info(f'Torch {torch.__version__}')

-        # Check if CUDA is available
-        if not torch.cuda.is_available():
-            log.warning('Torch reports CUDA not available')
-        else:
+        if torch.cuda.is_available():
             if torch.version.cuda:
-                if hasattr(torch, "xpu") and torch.xpu.is_available():
-                    # Log Intel IPEX OneAPI version
-                    log.info(f'Torch backend: Intel IPEX {ipex.__version__}')
-                else:
-                    # Log nVidia CUDA and cuDNN versions
-                    log.info(
-                        f'Torch backend: nVidia CUDA {torch.version.cuda} cuDNN {torch.backends.cudnn.version() if torch.backends.cudnn.is_available() else "N/A"}'
-                    )
+                # Log nVidia CUDA and cuDNN versions
+                log.info(
+                    f'Torch backend: nVidia CUDA {torch.version.cuda} cuDNN {torch.backends.cudnn.version() if torch.backends.cudnn.is_available() else "N/A"}'
+                )
             elif torch.version.hip:
                 # Log AMD ROCm HIP version
                 log.info(f'Torch backend: AMD ROCm HIP {torch.version.hip}')
@@ -75,15 +66,23 @@
             for device in [
                 torch.cuda.device(i) for i in range(torch.cuda.device_count())
             ]:
-                if hasattr(torch, "xpu") and torch.xpu.is_available():
-                    log.info(
-                        f'Torch detected GPU: {torch.xpu.get_device_name(device)} VRAM {round(torch.xpu.get_device_properties(device).total_memory / 1024 / 1024)} Compute Units {torch.xpu.get_device_properties(device).max_compute_units}'
-                    )
-                else:
-                    log.info(
-                        f'Torch detected GPU: {torch.cuda.get_device_name(device)} VRAM 
{round(torch.cuda.get_device_properties(device).total_memory / 1024 / 1024)} Arch {torch.cuda.get_device_capability(device)} Cores {torch.cuda.get_device_properties(device).multi_processor_count}' - ) - return int(torch.__version__[0]) + log.info( + f'Torch detected GPU: {torch.cuda.get_device_name(device)} VRAM {round(torch.cuda.get_device_properties(device).total_memory / 1024 / 1024)} Arch {torch.cuda.get_device_capability(device)} Cores {torch.cuda.get_device_properties(device).multi_processor_count}' + ) + # Check if XPU is available + elif hasattr(torch, "xpu") and torch.xpu.is_available(): + # Log Intel IPEX version + log.info(f'Torch backend: Intel IPEX {ipex.__version__}') + for device in [ + torch.xpu.device(i) for i in range(torch.xpu.device_count()) + ]: + log.info( + f'Torch detected GPU: {torch.xpu.get_device_name(device)} VRAM {round(torch.xpu.get_device_properties(device).total_memory / 1024 / 1024)} Compute Units {torch.xpu.get_device_properties(device).max_compute_units}' + ) + else: + log.warning('Torch reports GPU not available') + + return int(torch.__version__[0]) except Exception as e: log.error(f'Could not load torch: {e}') sys.exit(1) @@ -113,8 +112,8 @@ def main(): if args.requirements: setup_common.install_requirements(args.requirements, check_no_verify_flag=True) else: - setup_common.install_requirements('requirements_windows_torch2.txt', check_no_verify_flag=True) - + setup_common.install_requirements('requirements_pytorch_windows.txt', check_no_verify_flag=True) + setup_common.install_requirements('requirements_windows.txt', check_no_verify_flag=True) if __name__ == '__main__': main() diff --git a/test/config/DyLoRA-Adafactor-toml.json b/test/config/DyLoRA-Adafactor-toml.json new file mode 100644 index 000000000..12b4b2c05 --- /dev/null +++ b/test/config/DyLoRA-Adafactor-toml.json @@ -0,0 +1,139 @@ +{ + "LoRA_type": "LyCORIS/DyLoRA", + "LyCORIS_preset": "full", + "adaptive_noise_scale": 0, + "additional_parameters": "", + "block_alphas": "", + "block_dims": "", + "block_lr_zero_threshold": "", + "bucket_no_upscale": true, + "bucket_reso_steps": 1, + "bypass_mode": false, + "cache_latents": true, + "cache_latents_to_disk": true, + "caption_dropout_every_n_epochs": 0.0, + "caption_dropout_rate": 0.1, + "caption_extension": ".txt", + "clip_skip": "1", + "color_aug": false, + "constrain": 0.0, + "conv_alpha": 64, + "conv_block_alphas": "", + "conv_block_dims": "", + "conv_dim": 64, + "dataset_config": "./test/config/dataset.toml", + "debiased_estimation_loss": false, + "decompose_both": false, + "dim_from_weights": false, + "dora_wd": false, + "down_lr_weight": "", + "enable_bucket": true, + "epoch": 150, + "extra_accelerate_launch_args": "", + "factor": 6, + "flip_aug": false, + "fp8_base": false, + "full_bf16": false, + "full_fp16": false, + "gpu_ids": "", + "gradient_accumulation_steps": 1, + "gradient_checkpointing": false, + "ip_noise_gamma": 0, + "ip_noise_gamma_random_strength": false, + "keep_tokens": 1, + "learning_rate": 4e-07, + "log_tracker_config": "", + "log_tracker_name": "", + "logging_dir": "./test/logs", + "lora_network_weights": "", + "lr_scheduler": "constant_with_warmup", + "lr_scheduler_args": "", + "lr_scheduler_num_cycles": "", + "lr_scheduler_power": "", + "lr_warmup": 0, + "main_process_port": 0, + "masked_loss": false, + "max_bucket_reso": 2048, + "max_data_loader_n_workers": "0", + "max_grad_norm": 0, + "max_resolution": "512,512", + "max_timestep": 1000, + "max_token_length": "75", + "max_train_epochs": "", + "max_train_steps": "", + 
"mem_eff_attn": false, + "mid_lr_weight": "", + "min_bucket_reso": 256, + "min_snr_gamma": 5, + "min_timestep": 0, + "mixed_precision": "fp16", + "model_list": "custom", + "module_dropout": 0, + "multi_gpu": false, + "multires_noise_discount": 0.1, + "multires_noise_iterations": 6, + "network_alpha": 64, + "network_dim": 64, + "network_dropout": 0, + "noise_offset": 0, + "noise_offset_random_strength": false, + "noise_offset_type": "Multires", + "num_cpu_threads_per_process": 2, + "num_machines": 1, + "num_processes": 1, + "optimizer": "Adafactor", + "optimizer_args": "", + "output_dir": "./test/output", + "output_name": "DyLoRA-Adafactor-toml", + "persistent_data_loader_workers": false, + "pretrained_model_name_or_path": "runwayml/stable-diffusion-v1-5", + "prior_loss_weight": 1.0, + "random_crop": false, + "rank_dropout": 0, + "rank_dropout_scale": false, + "reg_data_dir": "", + "rescaled": false, + "resume": "", + "sample_every_n_epochs": 0, + "sample_every_n_steps": 25, + "sample_prompts": "a painting of a gas mask , by darius kawasaki", + "sample_sampler": "euler_a", + "save_every_n_epochs": 15, + "save_every_n_steps": 0, + "save_last_n_steps": 0, + "save_last_n_steps_state": 0, + "save_model_as": "safetensors", + "save_precision": "fp16", + "save_state": false, + "save_state_on_train_end": false, + "scale_v_pred_loss_like_noise_pred": false, + "scale_weight_norms": 0, + "sdxl": false, + "sdxl_cache_text_encoder_outputs": false, + "sdxl_no_half_vae": true, + "seed": "", + "shuffle_caption": true, + "stop_text_encoder_training": 0, + "text_encoder_lr": 4e-07, + "train_batch_size": 2, + "train_data_dir": "", + "train_norm": true, + "train_on_input": false, + "training_comment": "KoopaTroopa", + "unet_lr": 4e-07, + "unit": 1, + "up_lr_weight": "", + "use_cp": false, + "use_scalar": false, + "use_tucker": false, + "use_wandb": false, + "v2": false, + "v_parameterization": false, + "v_pred_like_loss": 0, + "vae": "", + "vae_batch_size": 0, + "wandb_api_key": "", + "wandb_run_name": "", + "weighted_captions": false, + "xformers": "xformers" +} \ No newline at end of file diff --git a/test/config/TI-AdamW8bit.json b/test/config/TI-AdamW8bit.json index 9785c01f8..aea6bc6b0 100644 --- a/test/config/TI-AdamW8bit.json +++ b/test/config/TI-AdamW8bit.json @@ -1,5 +1,5 @@ { - "adaptive_noise_scale": 0, + "adaptive_noise_scale": 0.005, "additional_parameters": "", "bucket_no_upscale": true, "bucket_reso_steps": 1, @@ -12,13 +12,15 @@ "color_aug": false, "dataset_config": "", "enable_bucket": true, - "epoch": 4, + "epoch": 8, "flip_aug": false, "full_fp16": false, "gpu_ids": "", "gradient_accumulation_steps": 1, "gradient_checkpointing": false, "init_word": "*", + "ip_noise_gamma": 0.1, + "ip_noise_gamma_random_strength": true, "keep_tokens": "0", "learning_rate": 0.0001, "log_tracker_config": "", @@ -47,7 +49,8 @@ "multires_noise_iterations": 8, "no_token_padding": false, "noise_offset": 0.05, - "noise_offset_type": "Multires", + "noise_offset_random_strength": true, + "noise_offset_type": "Original", "num_cpu_threads_per_process": 2, "num_machines": 1, "num_processes": 1, diff --git a/test/config/dataset-masked_loss.toml b/test/config/dataset-masked_loss.toml new file mode 100644 index 000000000..caf103654 --- /dev/null +++ b/test/config/dataset-masked_loss.toml @@ -0,0 +1,15 @@ +[[datasets]] +resolution = 512 +batch_size = 4 +keep_tokens = 1 +enable_bucket = true +min_bucket_reso = 64 +max_bucket_reso = 1024 +bucket_reso_steps = 32 +bucket_no_upscale = true + + [[datasets.subsets]] + image_dir = 
'.\test\img\10_darius kawasaki person' + num_repeats = 10 + caption_extension = '.txt' + conditioning_data_dir = '.\test\masked_loss' \ No newline at end of file diff --git a/test/config/dataset.toml b/test/config/dataset.toml index 2f90028a2..a35b93c7a 100644 --- a/test/config/dataset.toml +++ b/test/config/dataset.toml @@ -9,7 +9,7 @@ bucket_reso_steps = 32 bucket_no_upscale = true [[datasets.subsets]] - image_dir = '.\test\img\10_darius kawasaki person' + image_dir = './test/img/10_darius kawasaki person' num_repeats = 10 class_tokens = 'darius kawasaki person' caption_extension = '.txt' \ No newline at end of file diff --git a/test/config/dreambooth-AdamW8bit-masked_loss-toml.json b/test/config/dreambooth-AdamW8bit-masked_loss-toml.json new file mode 100644 index 000000000..8b06d541d --- /dev/null +++ b/test/config/dreambooth-AdamW8bit-masked_loss-toml.json @@ -0,0 +1,100 @@ +{ + "adaptive_noise_scale": 0, + "additional_parameters": "", + "bucket_no_upscale": true, + "bucket_reso_steps": 64, + "cache_latents": true, + "cache_latents_to_disk": false, + "caption_dropout_every_n_epochs": 0.0, + "caption_dropout_rate": 0.05, + "caption_extension": "", + "clip_skip": 2, + "color_aug": false, + "dataset_config": "D:/kohya_ss/test/config/dataset-masked_loss.toml", + "enable_bucket": true, + "epoch": 1, + "flip_aug": false, + "full_bf16": false, + "full_fp16": false, + "gpu_ids": "", + "gradient_accumulation_steps": 1, + "gradient_checkpointing": false, + "ip_noise_gamma": 0, + "ip_noise_gamma_random_strength": false, + "keep_tokens": "0", + "learning_rate": 5e-05, + "learning_rate_te": 1e-05, + "learning_rate_te1": 1e-05, + "learning_rate_te2": 1e-05, + "log_tracker_config": "", + "log_tracker_name": "", + "logging_dir": "./test/logs", + "lr_scheduler": "constant", + "lr_scheduler_args": "", + "lr_scheduler_num_cycles": "", + "lr_scheduler_power": "", + "lr_warmup": 0, + "masked_loss": true, + "max_bucket_reso": 2048, + "max_data_loader_n_workers": "0", + "max_resolution": "512,512", + "max_timestep": 1000, + "max_token_length": "75", + "max_train_epochs": "", + "max_train_steps": "", + "mem_eff_attn": false, + "min_bucket_reso": 256, + "min_snr_gamma": 0, + "min_timestep": 0, + "mixed_precision": "bf16", + "model_list": "runwayml/stable-diffusion-v1-5", + "multi_gpu": false, + "multires_noise_discount": 0, + "multires_noise_iterations": 0, + "no_token_padding": false, + "noise_offset": 0.05, + "noise_offset_random_strength": false, + "noise_offset_type": "Original", + "num_cpu_threads_per_process": 2, + "num_machines": 1, + "num_processes": 1, + "optimizer": "AdamW8bit", + "optimizer_args": "", + "output_dir": "./test/output", + "output_name": "db-AdamW8bit-masked_loss-toml", + "persistent_data_loader_workers": false, + "pretrained_model_name_or_path": "runwayml/stable-diffusion-v1-5", + "prior_loss_weight": 1.0, + "random_crop": false, + "reg_data_dir": "", + "resume": "", + "sample_every_n_epochs": 0, + "sample_every_n_steps": 25, + "sample_prompts": "a painting of a gas mask , by darius kawasaki", + "sample_sampler": "euler_a", + "save_every_n_epochs": 1, + "save_every_n_steps": 0, + "save_last_n_steps": 0, + "save_last_n_steps_state": 0, + "save_model_as": "safetensors", + "save_precision": "fp16", + "save_state": false, + "save_state_on_train_end": false, + "scale_v_pred_loss_like_noise_pred": false, + "sdxl": false, + "seed": "1234", + "shuffle_caption": false, + "stop_text_encoder_training": 0, + "train_batch_size": 4, + "train_data_dir": "", + "use_wandb": false, + "v2": false, + 
"v_parameterization": false, + "v_pred_like_loss": 0, + "vae": "", + "vae_batch_size": 0, + "wandb_api_key": "", + "wandb_run_name": "", + "weighted_captions": false, + "xformers": "xformers" +} \ No newline at end of file diff --git a/test/config/dreambooth-AdamW8bit-toml.json b/test/config/dreambooth-AdamW8bit-toml.json index 3dfeb3dd9..82344dee7 100644 --- a/test/config/dreambooth-AdamW8bit-toml.json +++ b/test/config/dreambooth-AdamW8bit-toml.json @@ -10,7 +10,7 @@ "caption_extension": "", "clip_skip": 2, "color_aug": false, - "dataset_config": "D:/kohya_ss/test/config/dataset.toml", + "dataset_config": "./test/config/dataset.toml", "enable_bucket": true, "epoch": 1, "flip_aug": false, @@ -19,6 +19,8 @@ "gpu_ids": "", "gradient_accumulation_steps": 1, "gradient_checkpointing": false, + "ip_noise_gamma": 0, + "ip_noise_gamma_random_strength": false, "keep_tokens": "0", "learning_rate": 5e-05, "learning_rate_te": 1e-05, @@ -32,6 +34,8 @@ "lr_scheduler_num_cycles": "", "lr_scheduler_power": "", "lr_warmup": 0, + "main_process_port": 12345, + "masked_loss": false, "max_bucket_reso": 2048, "max_data_loader_n_workers": "0", "max_resolution": "512,512", @@ -50,6 +54,7 @@ "multires_noise_iterations": 0, "no_token_padding": false, "noise_offset": 0.05, + "noise_offset_random_strength": false, "noise_offset_type": "Original", "num_cpu_threads_per_process": 2, "num_machines": 1, @@ -75,6 +80,7 @@ "save_model_as": "safetensors", "save_precision": "fp16", "save_state": false, + "save_state_on_train_end": false, "scale_v_pred_loss_like_noise_pred": false, "sdxl": false, "seed": "1234", diff --git a/test/config/dreambooth-AdamW8bit.json b/test/config/dreambooth-AdamW8bit.json index 2b4ae0187..72dd1cf35 100644 --- a/test/config/dreambooth-AdamW8bit.json +++ b/test/config/dreambooth-AdamW8bit.json @@ -1,8 +1,14 @@ { + "LoRA_type": "Kohya LoCon", + "LyCORIS_preset": "full", "adaptive_noise_scale": 0, "additional_parameters": "", + "block_alphas": "", + "block_dims": "", + "block_lr_zero_threshold": "", "bucket_no_upscale": true, "bucket_reso_steps": 64, + "bypass_mode": false, "cache_latents": true, "cache_latents_to_disk": false, "caption_dropout_every_n_epochs": 0.0, @@ -10,46 +16,66 @@ "caption_extension": "", "clip_skip": 2, "color_aug": false, + "constrain": 0.0, + "conv_alpha": 8, + "conv_block_alphas": "", + "conv_block_dims": "", + "conv_dim": 16, "dataset_config": "", + "debiased_estimation_loss": true, + "decompose_both": false, + "dim_from_weights": false, + "dora_wd": false, + "down_lr_weight": "", "enable_bucket": true, "epoch": 1, + "factor": -1, "flip_aug": false, + "fp8_base": false, "full_bf16": false, "full_fp16": false, "gpu_ids": "", "gradient_accumulation_steps": 1, "gradient_checkpointing": false, + "ip_noise_gamma": 0.1, + "ip_noise_gamma_random_strength": true, "keep_tokens": "0", "learning_rate": 5e-05, - "learning_rate_te": 1e-05, - "learning_rate_te1": 1e-05, - "learning_rate_te2": 1e-05, "log_tracker_config": "", "log_tracker_name": "", "logging_dir": "./test/logs", + "lora_network_weights": "", "lr_scheduler": "constant", "lr_scheduler_args": "", "lr_scheduler_num_cycles": "", "lr_scheduler_power": "", "lr_warmup": 0, + "main_process_port": 0, + "masked_loss": false, "max_bucket_reso": 2048, "max_data_loader_n_workers": "0", + "max_grad_norm": 1, "max_resolution": "512,512", "max_timestep": 1000, "max_token_length": "75", "max_train_epochs": "", "max_train_steps": "", "mem_eff_attn": false, + "mid_lr_weight": "", "min_bucket_reso": 256, "min_snr_gamma": 0, "min_timestep": 0, 
"mixed_precision": "bf16", "model_list": "runwayml/stable-diffusion-v1-5", + "module_dropout": 0.1, "multi_gpu": false, "multires_noise_discount": 0, "multires_noise_iterations": 0, - "no_token_padding": false, + "network_alpha": 8, + "network_dim": 16, + "network_dropout": 0.1, "noise_offset": 0.05, + "noise_offset_random_strength": true, "noise_offset_type": "Original", "num_cpu_threads_per_process": 2, "num_machines": 1, @@ -62,7 +88,10 @@ "pretrained_model_name_or_path": "runwayml/stable-diffusion-v1-5", "prior_loss_weight": 1.0, "random_crop": false, + "rank_dropout": 0.1, + "rank_dropout_scale": false, "reg_data_dir": "", + "rescaled": false, "resume": "", "sample_every_n_epochs": 0, "sample_every_n_steps": 25, @@ -75,13 +104,27 @@ "save_model_as": "safetensors", "save_precision": "fp16", "save_state": false, + "save_state_on_train_end": false, "scale_v_pred_loss_like_noise_pred": false, + "scale_weight_norms": 1, "sdxl": false, + "sdxl_cache_text_encoder_outputs": false, + "sdxl_no_half_vae": true, "seed": "1234", "shuffle_caption": false, "stop_text_encoder_training": 0, + "text_encoder_lr": 0.0, "train_batch_size": 4, "train_data_dir": "./test/img", + "train_norm": false, + "train_on_input": false, + "training_comment": "", + "unet_lr": 0.0, + "unit": 1, + "up_lr_weight": "", + "use_cp": false, + "use_scalar": false, + "use_tucker": false, "use_wandb": false, "v2": false, "v_parameterization": false, diff --git a/test/config/locon-Adafactor.json b/test/config/locon-Adafactor.json new file mode 100644 index 000000000..457a844d7 --- /dev/null +++ b/test/config/locon-Adafactor.json @@ -0,0 +1,138 @@ +{ + "LoRA_type": "Kohya LoCon", + "LyCORIS_preset": "full", + "adaptive_noise_scale": 0.005, + "additional_parameters": "", + "block_alphas": "", + "block_dims": "", + "block_lr_zero_threshold": "", + "bucket_no_upscale": true, + "bucket_reso_steps": 64, + "bypass_mode": false, + "cache_latents": true, + "cache_latents_to_disk": false, + "caption_dropout_every_n_epochs": 0.0, + "caption_dropout_rate": 0.05, + "caption_extension": "", + "clip_skip": 2, + "color_aug": false, + "constrain": 0.0, + "conv_alpha": 8, + "conv_block_alphas": "", + "conv_block_dims": "", + "conv_dim": 16, + "dataset_config": "", + "debiased_estimation_loss": false, + "decompose_both": false, + "dim_from_weights": false, + "dora_wd": false, + "down_lr_weight": "", + "enable_bucket": true, + "epoch": 8, + "factor": -1, + "flip_aug": false, + "fp8_base": false, + "full_bf16": false, + "full_fp16": false, + "gpu_ids": "", + "gradient_accumulation_steps": 1, + "gradient_checkpointing": false, + "ip_noise_gamma": 0.1, + "ip_noise_gamma_random_strength": true, + "keep_tokens": "0", + "learning_rate": 0.0005, + "log_tracker_config": "", + "log_tracker_name": "", + "logging_dir": "./test/logs", + "lora_network_weights": "", + "lr_scheduler": "constant", + "lr_scheduler_args": "", + "lr_scheduler_num_cycles": "", + "lr_scheduler_power": "", + "lr_warmup": 0, + "main_process_port": 0, + "masked_loss": false, + "max_bucket_reso": 2048, + "max_data_loader_n_workers": "0", + "max_grad_norm": 0, + "max_resolution": "512,512", + "max_timestep": 1000, + "max_token_length": "75", + "max_train_epochs": "", + "max_train_steps": "", + "mem_eff_attn": false, + "mid_lr_weight": "", + "min_bucket_reso": 256, + "min_snr_gamma": 0, + "min_timestep": 0, + "mixed_precision": "bf16", + "model_list": "runwayml/stable-diffusion-v1-5", + "module_dropout": 0.1, + "multi_gpu": false, + "multires_noise_discount": 0, + "multires_noise_iterations": 
0, + "network_alpha": 8, + "network_dim": 16, + "network_dropout": 0.1, + "noise_offset": 0.05, + "noise_offset_random_strength": true, + "noise_offset_type": "Original", + "num_cpu_threads_per_process": 2, + "num_machines": 1, + "num_processes": 1, + "optimizer": "Adafactor", + "optimizer_args": "scale_parameter=False relative_step=False warmup_init=False", + "output_dir": "./test/output", + "output_name": "locon-adafactor", + "persistent_data_loader_workers": false, + "pretrained_model_name_or_path": "runwayml/stable-diffusion-v1-5", + "prior_loss_weight": 1.0, + "random_crop": false, + "rank_dropout": 0.1, + "rank_dropout_scale": false, + "reg_data_dir": "", + "rescaled": false, + "resume": "", + "sample_every_n_epochs": 0, + "sample_every_n_steps": 25, + "sample_prompts": "a painting of a gas mask , by darius kawasaki", + "sample_sampler": "euler_a", + "save_every_n_epochs": 1, + "save_every_n_steps": 0, + "save_last_n_steps": 0, + "save_last_n_steps_state": 0, + "save_model_as": "safetensors", + "save_precision": "fp16", + "save_state": false, + "save_state_on_train_end": false, + "scale_v_pred_loss_like_noise_pred": false, + "scale_weight_norms": 1, + "sdxl": false, + "sdxl_cache_text_encoder_outputs": false, + "sdxl_no_half_vae": true, + "seed": "1234", + "shuffle_caption": false, + "stop_text_encoder_training": 0, + "text_encoder_lr": 0.0001, + "train_batch_size": 4, + "train_data_dir": "./test/img", + "train_norm": false, + "train_on_input": false, + "training_comment": "", + "unet_lr": 0.0001, + "unit": 1, + "up_lr_weight": "", + "use_cp": false, + "use_scalar": false, + "use_tucker": false, + "use_wandb": false, + "v2": false, + "v_parameterization": false, + "v_pred_like_loss": 0, + "vae": "", + "vae_batch_size": 0, + "wandb_api_key": "", + "wandb_run_name": "", + "weighted_captions": false, + "xformers": "xformers" +} \ No newline at end of file diff --git a/test/config/locon-AdamW8bit-masked_loss-toml.json b/test/config/locon-AdamW8bit-masked_loss-toml.json new file mode 100644 index 000000000..009fbcfc8 --- /dev/null +++ b/test/config/locon-AdamW8bit-masked_loss-toml.json @@ -0,0 +1,137 @@ +{ + "LoRA_type": "Standard", + "LyCORIS_preset": "full", + "adaptive_noise_scale": 0, + "additional_parameters": "", + "block_alphas": "", + "block_dims": "", + "block_lr_zero_threshold": "", + "bucket_no_upscale": true, + "bucket_reso_steps": 64, + "bypass_mode": false, + "cache_latents": true, + "cache_latents_to_disk": false, + "caption_dropout_every_n_epochs": 0.0, + "caption_dropout_rate": 0.05, + "caption_extension": "", + "clip_skip": 2, + "color_aug": false, + "constrain": 0.0, + "conv_alpha": 1, + "conv_block_alphas": "", + "conv_block_dims": "", + "conv_dim": 1, + "dataset_config": "D:/kohya_ss/test/config/dataset-masked_loss.toml", + "debiased_estimation_loss": false, + "decompose_both": false, + "dim_from_weights": false, + "dora_wd": false, + "down_lr_weight": "", + "enable_bucket": true, + "epoch": 1, + "factor": -1, + "flip_aug": false, + "fp8_base": false, + "full_bf16": false, + "full_fp16": false, + "gpu_ids": "", + "gradient_accumulation_steps": 1, + "gradient_checkpointing": false, + "ip_noise_gamma": 0, + "ip_noise_gamma_random_strength": false, + "keep_tokens": "0", + "learning_rate": 0.0005, + "log_tracker_config": "", + "log_tracker_name": "", + "logging_dir": "./test/logs", + "lora_network_weights": "", + "lr_scheduler": "constant", + "lr_scheduler_args": "", + "lr_scheduler_num_cycles": "", + "lr_scheduler_power": "", + "lr_warmup": 0, + "masked_loss": true, + 
"max_bucket_reso": 2048, + "max_data_loader_n_workers": "0", + "max_grad_norm": 1, + "max_resolution": "512,512", + "max_timestep": 1000, + "max_token_length": "75", + "max_train_epochs": "", + "max_train_steps": "", + "mem_eff_attn": false, + "mid_lr_weight": "", + "min_bucket_reso": 256, + "min_snr_gamma": 0, + "min_timestep": 0, + "mixed_precision": "bf16", + "model_list": "runwayml/stable-diffusion-v1-5", + "module_dropout": 0, + "multi_gpu": false, + "multires_noise_discount": 0, + "multires_noise_iterations": 0, + "network_alpha": 1, + "network_dim": 8, + "network_dropout": 0, + "noise_offset": 0.05, + "noise_offset_random_strength": false, + "noise_offset_type": "Original", + "num_cpu_threads_per_process": 2, + "num_machines": 1, + "num_processes": 1, + "optimizer": "AdamW8bit", + "optimizer_args": "", + "output_dir": "./test/output", + "output_name": "locon-AdamW8bit-masked_loss-toml", + "persistent_data_loader_workers": false, + "pretrained_model_name_or_path": "runwayml/stable-diffusion-v1-5", + "prior_loss_weight": 1.0, + "random_crop": false, + "rank_dropout": 0, + "rank_dropout_scale": false, + "reg_data_dir": "", + "rescaled": false, + "resume": "", + "sample_every_n_epochs": 0, + "sample_every_n_steps": 25, + "sample_prompts": "a painting of a gas mask , by darius kawasaki", + "sample_sampler": "euler_a", + "save_every_n_epochs": 1, + "save_every_n_steps": 0, + "save_last_n_steps": 0, + "save_last_n_steps_state": 0, + "save_model_as": "safetensors", + "save_precision": "fp16", + "save_state": false, + "save_state_on_train_end": false, + "scale_v_pred_loss_like_noise_pred": false, + "scale_weight_norms": 0, + "sdxl": false, + "sdxl_cache_text_encoder_outputs": false, + "sdxl_no_half_vae": true, + "seed": "1234", + "shuffle_caption": false, + "stop_text_encoder_training": 0, + "text_encoder_lr": 0.0, + "train_batch_size": 4, + "train_data_dir": "", + "train_norm": false, + "train_on_input": true, + "training_comment": "", + "unet_lr": 0.0, + "unit": 1, + "up_lr_weight": "", + "use_cp": false, + "use_scalar": false, + "use_tucker": false, + "use_wandb": false, + "v2": false, + "v_parameterization": false, + "v_pred_like_loss": 0, + "vae": "", + "vae_batch_size": 0, + "wandb_api_key": "", + "wandb_run_name": "", + "weighted_captions": false, + "xformers": "xformers" +} \ No newline at end of file diff --git a/test/config/locon-AdamW8bit.json b/test/config/locon-AdamW8bit.json index 99f2947b5..867fff738 100644 --- a/test/config/locon-AdamW8bit.json +++ b/test/config/locon-AdamW8bit.json @@ -1,13 +1,14 @@ { "LoRA_type": "Kohya LoCon", "LyCORIS_preset": "full", - "adaptive_noise_scale": 0, + "adaptive_noise_scale": 0.005, "additional_parameters": "", "block_alphas": "", "block_dims": "", "block_lr_zero_threshold": "", "bucket_no_upscale": true, "bucket_reso_steps": 64, + "bypass_mode": false, "cache_latents": true, "cache_latents_to_disk": false, "caption_dropout_every_n_epochs": 0.0, @@ -20,12 +21,14 @@ "conv_block_alphas": "", "conv_block_dims": "", "conv_dim": 16, - "debiased_estimation_loss": false, + "dataset_config": "", + "debiased_estimation_loss": true, "decompose_both": false, "dim_from_weights": false, + "dora_wd": false, "down_lr_weight": "", "enable_bucket": true, - "epoch": 1, + "epoch": 8, "factor": -1, "flip_aug": false, "fp8_base": false, @@ -34,6 +37,8 @@ "gpu_ids": "", "gradient_accumulation_steps": 1, "gradient_checkpointing": false, + "ip_noise_gamma": 0.1, + "ip_noise_gamma_random_strength": true, "keep_tokens": "0", "learning_rate": 0.0005, 
"log_tracker_config": "", @@ -45,6 +50,8 @@ "lr_scheduler_num_cycles": "", "lr_scheduler_power": "", "lr_warmup": 0, + "main_process_port": 0, + "masked_loss": false, "max_bucket_reso": 2048, "max_data_loader_n_workers": "0", "max_grad_norm": 1, @@ -68,6 +75,7 @@ "network_dim": 16, "network_dropout": 0.1, "noise_offset": 0.05, + "noise_offset_random_strength": true, "noise_offset_type": "Original", "num_cpu_threads_per_process": 2, "num_machines": 1, @@ -96,6 +104,7 @@ "save_model_as": "safetensors", "save_precision": "fp16", "save_state": false, + "save_state_on_train_end": false, "scale_v_pred_loss_like_noise_pred": false, "scale_weight_norms": 1, "sdxl": false, diff --git a/test/masked_loss/Dariusz_Zawadzki.jpg b/test/masked_loss/Dariusz_Zawadzki.jpg new file mode 100644 index 000000000..4358e6b9e Binary files /dev/null and b/test/masked_loss/Dariusz_Zawadzki.jpg differ diff --git a/test/masked_loss/Dariusz_Zawadzki_2.jpg b/test/masked_loss/Dariusz_Zawadzki_2.jpg new file mode 100644 index 000000000..cf5c489b3 Binary files /dev/null and b/test/masked_loss/Dariusz_Zawadzki_2.jpg differ diff --git a/test/masked_loss/Dariusz_Zawadzki_3.jpg b/test/masked_loss/Dariusz_Zawadzki_3.jpg new file mode 100644 index 000000000..ef89411c6 Binary files /dev/null and b/test/masked_loss/Dariusz_Zawadzki_3.jpg differ diff --git a/test/masked_loss/Dariusz_Zawadzki_4.jpg b/test/masked_loss/Dariusz_Zawadzki_4.jpg new file mode 100644 index 000000000..438602441 Binary files /dev/null and b/test/masked_loss/Dariusz_Zawadzki_4.jpg differ diff --git a/test/masked_loss/Dariusz_Zawadzki_5.jpg b/test/masked_loss/Dariusz_Zawadzki_5.jpg new file mode 100644 index 000000000..2b64b0da4 Binary files /dev/null and b/test/masked_loss/Dariusz_Zawadzki_5.jpg differ diff --git a/test/masked_loss/Dariusz_Zawadzki_6.jpg b/test/masked_loss/Dariusz_Zawadzki_6.jpg new file mode 100644 index 000000000..50d3f6a6d Binary files /dev/null and b/test/masked_loss/Dariusz_Zawadzki_6.jpg differ diff --git a/test/masked_loss/Dariusz_Zawadzki_7.jpg b/test/masked_loss/Dariusz_Zawadzki_7.jpg new file mode 100644 index 000000000..f70fe2a64 Binary files /dev/null and b/test/masked_loss/Dariusz_Zawadzki_7.jpg differ diff --git a/test/masked_loss/Dariusz_Zawadzki_8.jpg b/test/masked_loss/Dariusz_Zawadzki_8.jpg new file mode 100644 index 000000000..3f4507efe Binary files /dev/null and b/test/masked_loss/Dariusz_Zawadzki_8.jpg differ diff --git a/tools/caption_from_filename.py b/tools/caption_from_filename.py new file mode 100644 index 000000000..e579edcaf --- /dev/null +++ b/tools/caption_from_filename.py @@ -0,0 +1,50 @@ +# Proposed by https://github.com/kainatquaderee +import os +import argparse + +def main(image_directory, output_directory, image_extension, text_extension): + # Ensure the output directory exists, create it if necessary + os.makedirs(output_directory, exist_ok=True) + + # Initialize a counter for the number of text files created + text_files_created = 0 + + # Iterate through files in the directory + for image_filename in os.listdir(image_directory): + # Check if the file is an image + if any(image_filename.lower().endswith(ext) for ext in image_extension): + # Extract prompt from filename + prompt = os.path.splitext(image_filename)[0] + + # Construct path for the output text file + text_file_path = os.path.join(output_directory, prompt + text_extension) + + # Write prompt to text file + with open(text_file_path, 'w') as text_file: + text_file.write(prompt) + + print(f"Text file saved: {text_file_path}") + + # Increment 
the counter + text_files_created += 1 + + # Report if no text files were created + if text_files_created == 0: + print("No image matching extensions were found in the specified directory. No caption files were created.") + else: + print(f"{text_files_created} text files created successfully.") + +if __name__ == "__main__": + # Create an argument parser + parser = argparse.ArgumentParser(description='Generate caption files from image filenames.') + + # Add arguments for the image directory, output directory, and file extension + parser.add_argument('image_directory', help='Directory containing the image files') + parser.add_argument('output_directory', help='Output directory where text files will be saved') + parser.add_argument('--image_extension', nargs='+', default=['.jpg', '.jpeg', '.png', '.webp', '.bmp'], help='Extension for the image files') + parser.add_argument('--text_extension', default='.txt', help='Extension for the output text files') + + # Parse the command-line arguments + args = parser.parse_args() + + main(args.image_directory, args.output_directory, args.image_extension, args.text_extension) diff --git a/v2_inference/v2-inference-v.yaml b/v2_inference/v2-inference-v.yaml deleted file mode 100644 index 8ec8dfbfe..000000000 --- a/v2_inference/v2-inference-v.yaml +++ /dev/null @@ -1,68 +0,0 @@ -model: - base_learning_rate: 1.0e-4 - target: ldm.models.diffusion.ddpm.LatentDiffusion - params: - parameterization: "v" - linear_start: 0.00085 - linear_end: 0.0120 - num_timesteps_cond: 1 - log_every_t: 200 - timesteps: 1000 - first_stage_key: "jpg" - cond_stage_key: "txt" - image_size: 64 - channels: 4 - cond_stage_trainable: false - conditioning_key: crossattn - monitor: val/loss_simple_ema - scale_factor: 0.18215 - use_ema: False # we set this to false because this is an inference only config - - unet_config: - target: ldm.modules.diffusionmodules.openaimodel.UNetModel - params: - use_checkpoint: True - use_fp16: True - image_size: 32 # unused - in_channels: 4 - out_channels: 4 - model_channels: 320 - attention_resolutions: [ 4, 2, 1 ] - num_res_blocks: 2 - channel_mult: [ 1, 2, 4, 4 ] - num_head_channels: 64 # need to fix for flash-attn - use_spatial_transformer: True - use_linear_in_transformer: True - transformer_depth: 1 - context_dim: 1024 - legacy: False - - first_stage_config: - target: ldm.models.autoencoder.AutoencoderKL - params: - embed_dim: 4 - monitor: val/rec_loss - ddconfig: - #attn_type: "vanilla-xformers" - double_z: true - z_channels: 4 - resolution: 256 - in_channels: 3 - out_ch: 3 - ch: 128 - ch_mult: - - 1 - - 2 - - 4 - - 4 - num_res_blocks: 2 - attn_resolutions: [] - dropout: 0.0 - lossconfig: - target: torch.nn.Identity - - cond_stage_config: - target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder - params: - freeze: True - layer: "penultimate" diff --git a/v2_inference/v2-inference.yaml b/v2_inference/v2-inference.yaml deleted file mode 100644 index 152c4f3c2..000000000 --- a/v2_inference/v2-inference.yaml +++ /dev/null @@ -1,67 +0,0 @@ -model: - base_learning_rate: 1.0e-4 - target: ldm.models.diffusion.ddpm.LatentDiffusion - params: - linear_start: 0.00085 - linear_end: 0.0120 - num_timesteps_cond: 1 - log_every_t: 200 - timesteps: 1000 - first_stage_key: "jpg" - cond_stage_key: "txt" - image_size: 64 - channels: 4 - cond_stage_trainable: false - conditioning_key: crossattn - monitor: val/loss_simple_ema - scale_factor: 0.18215 - use_ema: False # we set this to false because this is an inference only config - - unet_config: - target: 
ldm.modules.diffusionmodules.openaimodel.UNetModel - params: - use_checkpoint: True - use_fp16: True - image_size: 32 # unused - in_channels: 4 - out_channels: 4 - model_channels: 320 - attention_resolutions: [ 4, 2, 1 ] - num_res_blocks: 2 - channel_mult: [ 1, 2, 4, 4 ] - num_head_channels: 64 # need to fix for flash-attn - use_spatial_transformer: True - use_linear_in_transformer: True - transformer_depth: 1 - context_dim: 1024 - legacy: False - - first_stage_config: - target: ldm.models.autoencoder.AutoencoderKL - params: - embed_dim: 4 - monitor: val/rec_loss - ddconfig: - #attn_type: "vanilla-xformers" - double_z: true - z_channels: 4 - resolution: 256 - in_channels: 3 - out_ch: 3 - ch: 128 - ch_mult: - - 1 - - 2 - - 4 - - 4 - num_res_blocks: 2 - attn_resolutions: [] - dropout: 0.0 - lossconfig: - target: torch.nn.Identity - - cond_stage_config: - target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder - params: - freeze: True - layer: "penultimate" diff --git a/v2_inference/v2-inpainting-inference.yaml b/v2_inference/v2-inpainting-inference.yaml deleted file mode 100644 index 32a9471d7..000000000 --- a/v2_inference/v2-inpainting-inference.yaml +++ /dev/null @@ -1,158 +0,0 @@ -model: - base_learning_rate: 5.0e-05 - target: ldm.models.diffusion.ddpm.LatentInpaintDiffusion - params: - linear_start: 0.00085 - linear_end: 0.0120 - num_timesteps_cond: 1 - log_every_t: 200 - timesteps: 1000 - first_stage_key: "jpg" - cond_stage_key: "txt" - image_size: 64 - channels: 4 - cond_stage_trainable: false - conditioning_key: hybrid - scale_factor: 0.18215 - monitor: val/loss_simple_ema - finetune_keys: null - use_ema: False - - unet_config: - target: ldm.modules.diffusionmodules.openaimodel.UNetModel - params: - use_checkpoint: True - image_size: 32 # unused - in_channels: 9 - out_channels: 4 - model_channels: 320 - attention_resolutions: [ 4, 2, 1 ] - num_res_blocks: 2 - channel_mult: [ 1, 2, 4, 4 ] - num_head_channels: 64 # need to fix for flash-attn - use_spatial_transformer: True - use_linear_in_transformer: True - transformer_depth: 1 - context_dim: 1024 - legacy: False - - first_stage_config: - target: ldm.models.autoencoder.AutoencoderKL - params: - embed_dim: 4 - monitor: val/rec_loss - ddconfig: - #attn_type: "vanilla-xformers" - double_z: true - z_channels: 4 - resolution: 256 - in_channels: 3 - out_ch: 3 - ch: 128 - ch_mult: - - 1 - - 2 - - 4 - - 4 - num_res_blocks: 2 - attn_resolutions: [ ] - dropout: 0.0 - lossconfig: - target: torch.nn.Identity - - cond_stage_config: - target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder - params: - freeze: True - layer: "penultimate" - - -data: - target: ldm.data.laion.WebDataModuleFromConfig - params: - tar_base: null # for concat as in LAION-A - p_unsafe_threshold: 0.1 - filter_word_list: "data/filters.yaml" - max_pwatermark: 0.45 - batch_size: 8 - num_workers: 6 - multinode: True - min_size: 512 - train: - shards: - - "pipe:aws s3 cp s3://stability-aws/laion-a-native/part-0/{00000..18699}.tar -" - - "pipe:aws s3 cp s3://stability-aws/laion-a-native/part-1/{00000..18699}.tar -" - - "pipe:aws s3 cp s3://stability-aws/laion-a-native/part-2/{00000..18699}.tar -" - - "pipe:aws s3 cp s3://stability-aws/laion-a-native/part-3/{00000..18699}.tar -" - - "pipe:aws s3 cp s3://stability-aws/laion-a-native/part-4/{00000..18699}.tar -" #{00000-94333}.tar" - shuffle: 10000 - image_key: jpg - image_transforms: - - target: torchvision.transforms.Resize - params: - size: 512 - interpolation: 3 - - target: torchvision.transforms.RandomCrop - params: - 
size: 512 - postprocess: - target: ldm.data.laion.AddMask - params: - mode: "512train-large" - p_drop: 0.25 - # NOTE use enough shards to avoid empty validation loops in workers - validation: - shards: - - "pipe:aws s3 cp s3://deep-floyd-s3/datasets/laion_cleaned-part5/{93001..94333}.tar - " - shuffle: 0 - image_key: jpg - image_transforms: - - target: torchvision.transforms.Resize - params: - size: 512 - interpolation: 3 - - target: torchvision.transforms.CenterCrop - params: - size: 512 - postprocess: - target: ldm.data.laion.AddMask - params: - mode: "512train-large" - p_drop: 0.25 - -lightning: - find_unused_parameters: True - modelcheckpoint: - params: - every_n_train_steps: 5000 - - callbacks: - metrics_over_trainsteps_checkpoint: - params: - every_n_train_steps: 10000 - - image_logger: - target: main.ImageLogger - params: - enable_autocast: False - disabled: False - batch_frequency: 1000 - max_images: 4 - increase_log_steps: False - log_first_step: False - log_images_kwargs: - use_ema_scope: False - inpaint: False - plot_progressive_rows: False - plot_diffusion_rows: False - N: 4 - unconditional_guidance_scale: 5.0 - unconditional_guidance_label: [""] - ddim_steps: 50 # todo check these out for depth2img, - ddim_eta: 0.0 # todo check these out for depth2img, - - trainer: - benchmark: True - val_check_interval: 5000000 - num_sanity_val_steps: 0 - accumulate_grad_batches: 1 diff --git a/v2_inference/v2-midas-inference.yaml b/v2_inference/v2-midas-inference.yaml deleted file mode 100644 index f20c30f61..000000000 --- a/v2_inference/v2-midas-inference.yaml +++ /dev/null @@ -1,74 +0,0 @@ -model: - base_learning_rate: 5.0e-07 - target: ldm.models.diffusion.ddpm.LatentDepth2ImageDiffusion - params: - linear_start: 0.00085 - linear_end: 0.0120 - num_timesteps_cond: 1 - log_every_t: 200 - timesteps: 1000 - first_stage_key: "jpg" - cond_stage_key: "txt" - image_size: 64 - channels: 4 - cond_stage_trainable: false - conditioning_key: hybrid - scale_factor: 0.18215 - monitor: val/loss_simple_ema - finetune_keys: null - use_ema: False - - depth_stage_config: - target: ldm.modules.midas.api.MiDaSInference - params: - model_type: "dpt_hybrid" - - unet_config: - target: ldm.modules.diffusionmodules.openaimodel.UNetModel - params: - use_checkpoint: True - image_size: 32 # unused - in_channels: 5 - out_channels: 4 - model_channels: 320 - attention_resolutions: [ 4, 2, 1 ] - num_res_blocks: 2 - channel_mult: [ 1, 2, 4, 4 ] - num_head_channels: 64 # need to fix for flash-attn - use_spatial_transformer: True - use_linear_in_transformer: True - transformer_depth: 1 - context_dim: 1024 - legacy: False - - first_stage_config: - target: ldm.models.autoencoder.AutoencoderKL - params: - embed_dim: 4 - monitor: val/rec_loss - ddconfig: - #attn_type: "vanilla-xformers" - double_z: true - z_channels: 4 - resolution: 256 - in_channels: 3 - out_ch: 3 - ch: 128 - ch_mult: - - 1 - - 2 - - 4 - - 4 - num_res_blocks: 2 - attn_resolutions: [ ] - dropout: 0.0 - lossconfig: - target: torch.nn.Identity - - cond_stage_config: - target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder - params: - freeze: True - layer: "penultimate" - - diff --git a/v2_inference/x4-upscaling.yaml b/v2_inference/x4-upscaling.yaml deleted file mode 100644 index 2db0964af..000000000 --- a/v2_inference/x4-upscaling.yaml +++ /dev/null @@ -1,76 +0,0 @@ -model: - base_learning_rate: 1.0e-04 - target: ldm.models.diffusion.ddpm.LatentUpscaleDiffusion - params: - parameterization: "v" - low_scale_key: "lr" - linear_start: 0.0001 - linear_end: 0.02 - 
num_timesteps_cond: 1 - log_every_t: 200 - timesteps: 1000 - first_stage_key: "jpg" - cond_stage_key: "txt" - image_size: 128 - channels: 4 - cond_stage_trainable: false - conditioning_key: "hybrid-adm" - monitor: val/loss_simple_ema - scale_factor: 0.08333 - use_ema: False - - low_scale_config: - target: ldm.modules.diffusionmodules.upscaling.ImageConcatWithNoiseAugmentation - params: - noise_schedule_config: # image space - linear_start: 0.0001 - linear_end: 0.02 - max_noise_level: 350 - - unet_config: - target: ldm.modules.diffusionmodules.openaimodel.UNetModel - params: - use_checkpoint: True - num_classes: 1000 # timesteps for noise conditioning (here constant, just need one) - image_size: 128 - in_channels: 7 - out_channels: 4 - model_channels: 256 - attention_resolutions: [ 2,4,8] - num_res_blocks: 2 - channel_mult: [ 1, 2, 2, 4] - disable_self_attentions: [True, True, True, False] - disable_middle_self_attn: False - num_heads: 8 - use_spatial_transformer: True - transformer_depth: 1 - context_dim: 1024 - legacy: False - use_linear_in_transformer: True - - first_stage_config: - target: ldm.models.autoencoder.AutoencoderKL - params: - embed_dim: 4 - ddconfig: - # attn_type: "vanilla-xformers" this model needs efficient attention to be feasible on HR data, also the decoder seems to break in half precision (UNet is fine though) - double_z: True - z_channels: 4 - resolution: 256 - in_channels: 3 - out_ch: 3 - ch: 128 - ch_mult: [ 1,2,4 ] # num_down = len(ch_mult)-1 - num_res_blocks: 2 - attn_resolutions: [ ] - dropout: 0.0 - - lossconfig: - target: torch.nn.Identity - - cond_stage_config: - target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder - params: - freeze: True - layer: "penultimate" -
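Note: the new tools/caption_from_filename.py helper added in this changeset is driven entirely by its argparse interface. A minimal invocation sketch follows; the directory paths are illustrative assumptions rather than paths shipped in the changeset, while the flags and their defaults come straight from the parser definition above.

# Hypothetical example: write one <image name>.txt caption file per matching image,
# each containing the image's base filename as the caption text.
# When --image_extension is omitted, the defaults .jpg .jpeg .png .webp .bmp apply.
python tools/caption_from_filename.py "./test/img/10_darius kawasaki person" \
    "./test/img/10_darius kawasaki person" \
    --image_extension .jpg .png \
    --text_extension .txt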