Skip to content

Commit

Permalink
[Experimental] Add cache mechanism for dataset groups to avoid long w…
Browse files Browse the repository at this point in the history
…aiting time for initilization (#1178)

* support meta cached dataset

* add cache meta scripts

* random ip_noise_gamma strength

* random noise_offset strength

* use correct settings for parser

* cache path/caption/size only

* revert mess up commit

* revert mess up commit

* Update requirements.txt

* Add arguments for meta cache.

* remove pickle implementation

* Return sizes when enable cache

---------

Co-authored-by: Kohya S <[email protected]>
  • Loading branch information
KohakuBlueleaf and kohya-ss authored Mar 24, 2024
1 parent 381c449 commit ae97c8b
Show file tree
Hide file tree
Showing 5 changed files with 173 additions and 22 deletions.
103 changes: 103 additions & 0 deletions cache_dataset_meta.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
import argparse
import random

from accelerate.utils import set_seed

import library.train_util as train_util
import library.config_util as config_util
from library.config_util import (
ConfigSanitizer,
BlueprintGenerator,
)
import library.custom_train_functions as custom_train_functions
from library.utils import setup_logging, add_logging_arguments

setup_logging()
import logging

logger = logging.getLogger(__name__)


def make_dataset(args):
train_util.prepare_dataset_args(args, True)
setup_logging(args, reset=True)

use_dreambooth_method = args.in_json is None
use_user_config = args.dataset_config is not None

if args.seed is None:
args.seed = random.randint(0, 2**32)
set_seed(args.seed)

# データセットを準備する
if args.dataset_class is None:
blueprint_generator = BlueprintGenerator(
ConfigSanitizer(True, True, False, True)
)
if use_user_config:
logger.info(f"Loading dataset config from {args.dataset_config}")
user_config = config_util.load_user_config(args.dataset_config)
ignored = ["train_data_dir", "reg_data_dir", "in_json"]
if any(getattr(args, attr) is not None for attr in ignored):
logger.warning(
"ignoring the following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
", ".join(ignored)
)
)
else:
if use_dreambooth_method:
logger.info("Using DreamBooth method.")
user_config = {
"datasets": [
{
"subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(
args.train_data_dir, args.reg_data_dir
)
}
]
}
else:
logger.info("Training with captions.")
user_config = {
"datasets": [
{
"subsets": [
{
"image_dir": args.train_data_dir,
"metadata_file": args.in_json,
}
]
}
]
}

blueprint = blueprint_generator.generate(user_config, args, tokenizer=None)
train_dataset_group = config_util.generate_dataset_group_by_blueprint(
blueprint.dataset_group
)
else:
# use arbitrary dataset class
train_dataset_group = train_util.load_arbitrary_dataset(args, tokenizer=None)
return train_dataset_group


def setup_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser()
add_logging_arguments(parser)
train_util.add_dataset_arguments(parser, True, True, True)
train_util.add_training_arguments(parser, True)
config_util.add_config_arguments(parser)
custom_train_functions.add_custom_train_arguments(parser)
return parser


if __name__ == "__main__":
parser = setup_parser()

args, unknown = parser.parse_known_args()
args = train_util.read_config_from_file(args, parser)
if args.max_token_length is None:
args.max_token_length = 75
args.cache_meta = True

dataset_group = make_dataset(args)
4 changes: 4 additions & 0 deletions library/config_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,8 @@ class DreamBoothDatasetParams(BaseDatasetParams):
bucket_reso_steps: int = 64
bucket_no_upscale: bool = False
prior_loss_weight: float = 1.0
cache_meta: bool = False
use_cached_meta: bool = False


@dataclass
Expand Down Expand Up @@ -228,6 +230,8 @@ def __validate_and_convert_scalar_or_twodim(klass, value: Union[float, Sequence]
"min_bucket_reso": int,
"resolution": functools.partial(__validate_and_convert_scalar_or_twodim.__func__, int),
"network_multiplier": float,
"cache_meta": bool,
"use_cached_meta": bool,
}

# options handled by argparse but not handled by user config
Expand Down
83 changes: 62 additions & 21 deletions library/train_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@
from huggingface_hub import hf_hub_download
import numpy as np
from PIL import Image
import imagesize
import cv2
import safetensors.torch
from library.lpw_stable_diffusion import StableDiffusionLongPromptWeightingPipeline
Expand Down Expand Up @@ -1080,8 +1081,7 @@ def cache_text_encoder_outputs(
)

def get_image_size(self, image_path):
image = Image.open(image_path)
return image.size
return imagesize.get(image_path)

def load_image_with_face_info(self, subset: BaseSubset, image_path: str):
img = load_image(image_path)
Expand Down Expand Up @@ -1425,6 +1425,8 @@ def __init__(
bucket_no_upscale: bool,
prior_loss_weight: float,
debug_dataset: bool,
cache_meta: bool,
use_cached_meta: bool,
) -> None:
super().__init__(tokenizer, max_token_length, resolution, network_multiplier, debug_dataset)

Expand Down Expand Up @@ -1484,26 +1486,43 @@ def load_dreambooth_dir(subset: DreamBoothSubset):
logger.warning(f"not directory: {subset.image_dir}")
return [], []

img_paths = glob_images(subset.image_dir, "*")
sizes = None
if use_cached_meta:
logger.info(f"using cached metadata: {subset.image_dir}/dataset.txt")
# [img_path, caption, resolution]
with open(f"{subset.image_dir}/dataset.txt", "r", encoding="utf-8") as f:
metas = f.readlines()
metas = [x.strip().split("<|##|>") for x in metas]
sizes = [tuple(int(res) for res in x[2].split(" ")) for x in metas]

if use_cached_meta:
img_paths = [x[0] for x in metas]
else:
img_paths = glob_images(subset.image_dir, "*")
sizes = [None]*len(img_paths)
logger.info(f"found directory {subset.image_dir} contains {len(img_paths)} image files")

# 画像ファイルごとにプロンプトを読み込み、もしあればそちらを使う
captions = []
missing_captions = []
for img_path in img_paths:
cap_for_img = read_caption(img_path, subset.caption_extension, subset.enable_wildcard)
if cap_for_img is None and subset.class_tokens is None:
logger.warning(
f"neither caption file nor class tokens are found. use empty caption for {img_path} / キャプションファイルもclass tokenも見つかりませんでした。空のキャプションを使用します: {img_path}"
)
captions.append("")
missing_captions.append(img_path)
else:
if cap_for_img is None:
captions.append(subset.class_tokens)
if use_cached_meta:
captions = [x[1] for x in metas]
missing_captions = [x[0] for x in metas if x[1] == ""]
else:
# 画像ファイルごとにプロンプトを読み込み、もしあればそちらを使う
captions = []
missing_captions = []
for img_path in img_paths:
cap_for_img = read_caption(img_path, subset.caption_extension, subset.enable_wildcard)
if cap_for_img is None and subset.class_tokens is None:
logger.warning(
f"neither caption file nor class tokens are found. use empty caption for {img_path} / キャプションファイルもclass tokenも見つかりませんでした。空のキャプションを使用します: {img_path}"
)
captions.append("")
missing_captions.append(img_path)
else:
captions.append(cap_for_img)
if cap_for_img is None:
captions.append(subset.class_tokens)
missing_captions.append(img_path)
else:
captions.append(cap_for_img)

self.set_tag_frequency(os.path.basename(subset.image_dir), captions) # タグ頻度を記録

Expand All @@ -1520,7 +1539,21 @@ def load_dreambooth_dir(subset: DreamBoothSubset):
logger.warning(missing_caption + f"... and {remaining_missing_captions} more")
break
logger.warning(missing_caption)
return img_paths, captions

if cache_meta:
logger.info(f"cache metadata for {subset.image_dir}")
if sizes is None or sizes[0] is None:
sizes = [self.get_image_size(img_path) for img_path in img_paths]
# [img_path, caption, resolution]
data = [
(img_path, caption, " ".join(str(x) for x in size))
for img_path, caption, size in zip(img_paths, captions, sizes)
]
with open(f"{subset.image_dir}/dataset.txt", "w", encoding="utf-8") as f:
f.write("\n".join(["<|##|>".join(x) for x in data]))
logger.info(f"cache metadata done for {subset.image_dir}")

return img_paths, captions, sizes

logger.info("prepare images.")
num_train_images = 0
Expand All @@ -1539,7 +1572,7 @@ def load_dreambooth_dir(subset: DreamBoothSubset):
)
continue

img_paths, captions = load_dreambooth_dir(subset)
img_paths, captions, sizes = load_dreambooth_dir(subset)
if len(img_paths) < 1:
logger.warning(
f"ignore subset with image_dir='{subset.image_dir}': no images found / 画像が見つからないためサブセットを無視します"
Expand All @@ -1551,8 +1584,10 @@ def load_dreambooth_dir(subset: DreamBoothSubset):
else:
num_train_images += subset.num_repeats * len(img_paths)

for img_path, caption in zip(img_paths, captions):
for img_path, caption, size in zip(img_paths, captions, sizes):
info = ImageInfo(img_path, subset.num_repeats, caption, subset.is_reg, img_path)
if size is not None:
info.image_size = size
if subset.is_reg:
reg_infos.append((info, subset))
else:
Expand Down Expand Up @@ -3355,6 +3390,12 @@ def add_dataset_arguments(
parser: argparse.ArgumentParser, support_dreambooth: bool, support_caption: bool, support_caption_dropout: bool
):
# dataset common
parser.add_argument(
"--cache_meta", action="store_true"
)
parser.add_argument(
"--use_cached_meta", action="store_true"
)
parser.add_argument(
"--train_data_dir", type=str, default=None, help="directory for train images / 学習画像データのディレクトリ"
)
Expand Down
2 changes: 2 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ easygui==0.98.3
toml==0.10.2
voluptuous==0.13.1
huggingface-hub==0.20.1
# for Image utils
imagesize==1.4.1
# for BLIP captioning
# requests==2.28.2
# timm==0.6.12
Expand Down
3 changes: 2 additions & 1 deletion train_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import random
import time
import json
import pickle
from multiprocessing import Value
import toml

Expand All @@ -23,7 +24,7 @@

import library.train_util as train_util
from library.train_util import (
DreamBoothDataset,
DreamBoothDataset, DatasetGroup
)
import library.config_util as config_util
from library.config_util import (
Expand Down

0 comments on commit ae97c8b

Please sign in to comment.