[Experimental] Add cache mechanism for dataset groups to avoid long waiting time for initialization #1178

Merged · 13 commits · Mar 24, 2024
cache_dataset_meta.py (new file, 103 additions)
@@ -0,0 +1,103 @@
import argparse
import random

from accelerate.utils import set_seed

import library.train_util as train_util
import library.config_util as config_util
from library.config_util import (
ConfigSanitizer,
BlueprintGenerator,
)
import library.custom_train_functions as custom_train_functions
from library.utils import setup_logging, add_logging_arguments

setup_logging()
import logging

logger = logging.getLogger(__name__)


def make_dataset(args):
train_util.prepare_dataset_args(args, True)
setup_logging(args, reset=True)

use_dreambooth_method = args.in_json is None
use_user_config = args.dataset_config is not None

if args.seed is None:
args.seed = random.randint(0, 2**32)
set_seed(args.seed)

# データセットを準備する
if args.dataset_class is None:
blueprint_generator = BlueprintGenerator(
ConfigSanitizer(True, True, False, True)
)
if use_user_config:
logger.info(f"Loading dataset config from {args.dataset_config}")
user_config = config_util.load_user_config(args.dataset_config)
ignored = ["train_data_dir", "reg_data_dir", "in_json"]
if any(getattr(args, attr) is not None for attr in ignored):
logger.warning(
"ignoring the following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
", ".join(ignored)
)
)
else:
if use_dreambooth_method:
logger.info("Using DreamBooth method.")
user_config = {
"datasets": [
{
"subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(
args.train_data_dir, args.reg_data_dir
)
}
]
}
else:
logger.info("Training with captions.")
user_config = {
"datasets": [
{
"subsets": [
{
"image_dir": args.train_data_dir,
"metadata_file": args.in_json,
}
]
}
]
}

blueprint = blueprint_generator.generate(user_config, args, tokenizer=None)
train_dataset_group = config_util.generate_dataset_group_by_blueprint(
blueprint.dataset_group
)
else:
# use arbitrary dataset class
train_dataset_group = train_util.load_arbitrary_dataset(args, tokenizer=None)
return train_dataset_group


def setup_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser()
add_logging_arguments(parser)
train_util.add_dataset_arguments(parser, True, True, True)
train_util.add_training_arguments(parser, True)
config_util.add_config_arguments(parser)
custom_train_functions.add_custom_train_arguments(parser)
return parser


if __name__ == "__main__":
parser = setup_parser()

args, unknown = parser.parse_known_args()
args = train_util.read_config_from_file(args, parser)
if args.max_token_length is None:
args.max_token_length = 75
args.cache_meta = True

dataset_group = make_dataset(args)
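
The script above is the ahead-of-time entry point: running it once (for example `python cache_dataset_meta.py --train_data_dir <dir> --resolution 512,512`, or with `--dataset_config`) builds the dataset group with `cache_meta` forced on in the `__main__` block, which for DreamBooth-style subsets writes a `dataset.txt` metadata file into each image directory via `load_dreambooth_dir` in the `library/train_util.py` changes below. The exact flags needed beyond those registered in `setup_parser` depend on the dataset setup.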
library/config_util.py (4 additions, 0 deletions)
@@ -111,6 +111,8 @@ class DreamBoothDatasetParams(BaseDatasetParams):
     bucket_reso_steps: int = 64
     bucket_no_upscale: bool = False
     prior_loss_weight: float = 1.0
+    cache_meta: bool = False
+    use_cached_meta: bool = False
 
 
 @dataclass
@@ -228,6 +230,8 @@ def __validate_and_convert_scalar_or_twodim(klass, value: Union[float, Sequence]
         "min_bucket_reso": int,
         "resolution": functools.partial(__validate_and_convert_scalar_or_twodim.__func__, int),
         "network_multiplier": float,
+        "cache_meta": bool,
+        "use_cached_meta": bool,
     }
 
     # options handled by argparse but not handled by user config
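
Because the two keys are added to both `DreamBoothDatasetParams` and the sanitizer schema, they should also be settable from the dataset config rather than only via argparse. A minimal sketch of where they would likely sit, assuming dataset-level placement alongside the other `DreamBoothDatasetParams` fields (the exact accepted section should be checked against the schema above); shown in the Python-dict form that `BlueprintGenerator.generate` consumes, with placeholder paths:

```python
user_config = {
    "datasets": [
        {
            "cache_meta": True,         # write dataset.txt while the dataset group is built
            "use_cached_meta": False,   # read dataset.txt instead of globbing the image dir
            "subsets": [{"image_dir": "/path/to/train_images"}],  # placeholder path
        }
    ]
}
```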
library/train_util.py (62 additions, 21 deletions)
@@ -63,6 +63,7 @@
 from huggingface_hub import hf_hub_download
 import numpy as np
 from PIL import Image
+import imagesize
 import cv2
 import safetensors.torch
 from library.lpw_stable_diffusion import StableDiffusionLongPromptWeightingPipeline
@@ -1080,8 +1081,7 @@ def cache_text_encoder_outputs(
         )
 
     def get_image_size(self, image_path):
-        image = Image.open(image_path)
-        return image.size
+        return imagesize.get(image_path)
 
     def load_image_with_face_info(self, subset: BaseSubset, image_path: str):
         img = load_image(image_path)
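
The switch from PIL to imagesize here is presumably the second half of the speed-up: when resolutions are not cached, the dataset still has to learn each image's size, and `imagesize.get` only parses the file header instead of constructing a PIL `Image` object. A rough illustration (not part of the PR):

```python
from PIL import Image
import imagesize

path = "sample.png"  # hypothetical file

w1, h1 = Image.open(path).size  # opens the file and builds a PIL Image object
w2, h2 = imagesize.get(path)    # reads only the header bytes needed for the dimensions
assert (w1, h1) == (w2, h2)     # both report (width, height)
```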
@@ -1425,6 +1425,8 @@ def __init__(
         bucket_no_upscale: bool,
         prior_loss_weight: float,
         debug_dataset: bool,
+        cache_meta: bool,
+        use_cached_meta: bool,
     ) -> None:
         super().__init__(tokenizer, max_token_length, resolution, network_multiplier, debug_dataset)
 
@@ -1484,26 +1486,43 @@ def load_dreambooth_dir(subset: DreamBoothSubset):
                 logger.warning(f"not directory: {subset.image_dir}")
                 return [], []
 
-            img_paths = glob_images(subset.image_dir, "*")
+            sizes = None
+            if use_cached_meta:
+                logger.info(f"using cached metadata: {subset.image_dir}/dataset.txt")
+                # [img_path, caption, resolution]
+                with open(f"{subset.image_dir}/dataset.txt", "r", encoding="utf-8") as f:
+                    metas = f.readlines()
+                metas = [x.strip().split("<|##|>") for x in metas]
+                sizes = [tuple(int(res) for res in x[2].split(" ")) for x in metas]
+
+            if use_cached_meta:
+                img_paths = [x[0] for x in metas]
+            else:
+                img_paths = glob_images(subset.image_dir, "*")
+                sizes = [None]*len(img_paths)
             logger.info(f"found directory {subset.image_dir} contains {len(img_paths)} image files")
 
-            # 画像ファイルごとにプロンプトを読み込み、もしあればそちらを使う
-            captions = []
-            missing_captions = []
-            for img_path in img_paths:
-                cap_for_img = read_caption(img_path, subset.caption_extension, subset.enable_wildcard)
-                if cap_for_img is None and subset.class_tokens is None:
-                    logger.warning(
-                        f"neither caption file nor class tokens are found. use empty caption for {img_path} / キャプションファイルもclass tokenも見つかりませんでした。空のキャプションを使用します: {img_path}"
-                    )
-                    captions.append("")
-                    missing_captions.append(img_path)
-                else:
-                    if cap_for_img is None:
-                        captions.append(subset.class_tokens)
-                        missing_captions.append(img_path)
-                    else:
-                        captions.append(cap_for_img)
+            if use_cached_meta:
+                captions = [x[1] for x in metas]
+                missing_captions = [x[0] for x in metas if x[1] == ""]
+            else:
+                # 画像ファイルごとにプロンプトを読み込み、もしあればそちらを使う
+                captions = []
+                missing_captions = []
+                for img_path in img_paths:
+                    cap_for_img = read_caption(img_path, subset.caption_extension, subset.enable_wildcard)
+                    if cap_for_img is None and subset.class_tokens is None:
+                        logger.warning(
+                            f"neither caption file nor class tokens are found. use empty caption for {img_path} / キャプションファイルもclass tokenも見つかりませんでした。空のキャプションを使用します: {img_path}"
+                        )
+                        captions.append("")
+                        missing_captions.append(img_path)
+                    else:
+                        if cap_for_img is None:
+                            captions.append(subset.class_tokens)
+                            missing_captions.append(img_path)
+                        else:
+                            captions.append(cap_for_img)
 
             self.set_tag_frequency(os.path.basename(subset.image_dir), captions)  # タグ頻度を記録
 
@@ -1520,7 +1539,21 @@ def load_dreambooth_dir(subset: DreamBoothSubset):
                         logger.warning(missing_caption + f"... and {remaining_missing_captions} more")
                         break
                     logger.warning(missing_caption)
-            return img_paths, captions
+
+            if cache_meta:
+                logger.info(f"cache metadata for {subset.image_dir}")
+                if sizes is None or sizes[0] is None:
+                    sizes = [self.get_image_size(img_path) for img_path in img_paths]
+                # [img_path, caption, resolution]
+                data = [
+                    (img_path, caption, " ".join(str(x) for x in size))
+                    for img_path, caption, size in zip(img_paths, captions, sizes)
+                ]
+                with open(f"{subset.image_dir}/dataset.txt", "w", encoding="utf-8") as f:
+                    f.write("\n".join(["<|##|>".join(x) for x in data]))
+                logger.info(f"cache metadata done for {subset.image_dir}")
+
+            return img_paths, captions, sizes
 
         logger.info("prepare images.")
         num_train_images = 0
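
The read branch earlier and this write branch together define the cache format: a `dataset.txt` per image directory, one record per line, fields joined by the `<|##|>` separator, and the resolution stored as `"width height"`. A small self-contained sketch of that round trip (paths and captions are made up):

```python
records = [
    ("/data/train/img_0001.png", "a photo of sks dog", (768, 1024)),
    ("/data/train/img_0002.png", "sks dog on a beach", (1024, 1024)),
]

# write: mirrors the cache_meta branch above
lines = ["<|##|>".join([path, caption, f"{w} {h}"]) for path, caption, (w, h) in records]
with open("dataset.txt", "w", encoding="utf-8") as f:
    f.write("\n".join(lines))

# read: mirrors the use_cached_meta branch in load_dreambooth_dir
with open("dataset.txt", "r", encoding="utf-8") as f:
    metas = [line.strip().split("<|##|>") for line in f]
img_paths = [m[0] for m in metas]
captions = [m[1] for m in metas]
sizes = [tuple(int(v) for v in m[2].split(" ")) for m in metas]
```

One consequence of this format: captions that contain a newline or the literal separator string would corrupt the cache, so it relies on captions being single-line text.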
@@ -1539,7 +1572,7 @@ def load_dreambooth_dir(subset: DreamBoothSubset):
                 )
                 continue
 
-            img_paths, captions = load_dreambooth_dir(subset)
+            img_paths, captions, sizes = load_dreambooth_dir(subset)
             if len(img_paths) < 1:
                 logger.warning(
                     f"ignore subset with image_dir='{subset.image_dir}': no images found / 画像が見つからないためサブセットを無視します"
@@ -1551,8 +1584,10 @@ def load_dreambooth_dir(subset: DreamBoothSubset):
             else:
                 num_train_images += subset.num_repeats * len(img_paths)
 
-            for img_path, caption in zip(img_paths, captions):
+            for img_path, caption, size in zip(img_paths, captions, sizes):
                 info = ImageInfo(img_path, subset.num_repeats, caption, subset.is_reg, img_path)
+                if size is not None:
+                    info.image_size = size
                 if subset.is_reg:
                     reg_infos.append((info, subset))
                 else:
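
Seeding `ImageInfo.image_size` from the cached resolution is what lets a cached run avoid touching the image files at startup: when the size is already set, the dataset presumably no longer needs to probe each file for its dimensions during bucket assignment, which is where most of the initialization time goes for large datasets.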
@@ -3355,6 +3390,12 @@ def add_dataset_arguments(
     parser: argparse.ArgumentParser, support_dreambooth: bool, support_caption: bool, support_caption_dropout: bool
 ):
     # dataset common
+    parser.add_argument(
+        "--cache_meta", action="store_true"
+    )
+    parser.add_argument(
+        "--use_cached_meta", action="store_true"
+    )
     parser.add_argument(
         "--train_data_dir", type=str, default=None, help="directory for train images / 学習画像データのディレクトリ"
     )
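
With both flags registered for every script that calls `add_dataset_arguments`, the intended workflow appears to be: run once with `--cache_meta` (or use `cache_dataset_meta.py` above, which forces it on) so each image directory gets a `dataset.txt`, then pass `--use_cached_meta` on later runs so `load_dreambooth_dir` takes paths, captions, and resolutions from that file instead of globbing the directory and reading every image header.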
requirements.txt (2 additions, 0 deletions)
@@ -15,6 +15,8 @@ easygui==0.98.3
 toml==0.10.2
 voluptuous==0.13.1
 huggingface-hub==0.20.1
+# for Image utils
+imagesize==1.4.1
 # for BLIP captioning
 # requests==2.28.2
 # timm==0.6.12
train_network.py (2 additions, 1 deletion)
@@ -6,6 +6,7 @@
 import random
 import time
 import json
+import pickle
 from multiprocessing import Value
 import toml
 
@@ -23,7 +24,7 @@
 
 import library.train_util as train_util
 from library.train_util import (
-    DreamBoothDataset,
+    DreamBoothDataset, DatasetGroup
 )
 import library.config_util as config_util
 from library.config_util import (