diff --git a/cog.yaml b/cog.yaml
index 3752b5a..5f5c69e 100644
--- a/cog.yaml
+++ b/cog.yaml
@@ -50,7 +50,7 @@ build:
     - "numpy==1.26.0"
     - "shortuuid==1.0.11"
     - "tokenizers==0.19"
-    - "wandb==0.15.12"
+    - "wandb==0.17.8"
     - "wavedrom==2.0.3.post3"
     - "Pygments==2.16.1"
   run:
diff --git a/train.py b/train.py
index cf6a519..90e2d27 100644
--- a/train.py
+++ b/train.py
@@ -27,24 +27,71 @@ from toolkit.config import get_config
 
 
 from caption import Captioner
+from wandb_client import WeightsAndBiasesClient
+
+JOB_NAME = "flux_train_replicate"
 
 WEIGHTS_PATH = Path("./FLUX.1-dev")
 INPUT_DIR = Path("input_images")
 OUTPUT_DIR = Path("output")
+JOB_DIR = OUTPUT_DIR / JOB_NAME
 
 
 class CustomSDTrainer(SDTrainer):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        # Sample file names already reported, so each image is logged once.
+        self.seen_samples = set()
+        # Injected by CustomJob; stays None when W&B logging is disabled.
+        self.wandb: WeightsAndBiasesClient | None = None
+
     def hook_train_loop(self, batch):
-        # TODO: Add W&B logging, etc.
-        return super().hook_train_loop(batch)
+        """Delegate to SDTrainer, then forward the returned loss dict to W&B."""
+        loss_dict = super().hook_train_loop(batch)
+        if self.wandb:
+            self.wandb.log_loss(loss_dict, self.step_num)
+        return loss_dict
+
+    def sample(self, step=None, is_first=False):
+        """Run the parent sampler, then log newly written .jpg files to W&B."""
+        super().sample(step=step, is_first=is_first)
+        output_dir = JOB_DIR / "samples"
+        all_samples = {p.name for p in output_dir.glob("*.jpg")}
+        new_samples = all_samples - self.seen_samples
+        if self.wandb:
+            image_paths = [output_dir / p for p in sorted(new_samples)]
+            self.wandb.log_samples(image_paths, step)
+        self.seen_samples = all_samples
+
+    def post_save_hook(self, save_path):
+        """Upload the latest saved LoRA weights to W&B when logging is on."""
+        super().post_save_hook(save_path)
+        if not self.wandb:
+            return
+        # final lora path
+        lora_path = JOB_DIR / f"{JOB_NAME}.safetensors"
+        if not lora_path.exists():
+            # intermediate saved weights (assumes file names sort by step)
+            checkpoints = sorted(JOB_DIR.glob("*.safetensors"))
+            if not checkpoints:
+                # Nothing on disk yet; skip instead of raising IndexError.
+                return
+            lora_path = checkpoints[-1]
+        print(f"Saving weights to W&B: {lora_path.name}")
+        self.wandb.save_weights(lora_path)
 
 
 class CustomJob(BaseJob):
-    def __init__(self, config: OrderedDict):
+    def __init__(
+        self, config: OrderedDict, wandb_client: WeightsAndBiasesClient | None
+    ):
         super().__init__(config)
         self.device = self.get_conf("device", "cpu")
         self.process_dict = {"custom_sd_trainer": CustomSDTrainer}
         self.load_processes(self.process_dict)
+        # Hand the (optional) W&B client to every loaded trainer process.
+        for process in self.process:
+            process.wandb = wandb_client
 
     def run(self):
         super().run()
@@ -82,7 +129,7 @@ def train(
     ),
     steps: int = Input(
         description="Number of training steps. Recommended range 500-4000",
-        ge=10,
+        ge=3,
         le=6000,
         default=1000,
     ),
@@ -120,6 +167,36 @@ def train(
         description="Hugging Face token, if you'd like to upload the trained LoRA to Hugging Face.",
         default=None,
     ),
+    wandb_api_key: Secret = Input(
+        description="Weights and Biases API key, if you'd like to log training progress to W&B.",
+        default=None,
+    ),
+    wandb_project: str = Input(
+        description="Weights and Biases project name. Only applicable if wandb_api_key is set.",
+        default=JOB_NAME,
+    ),
+    wandb_run: str = Input(
+        description="Weights and Biases run name. Only applicable if wandb_api_key is set.",
+        default=None,
+    ),
+    wandb_entity: str = Input(
+        description="Weights and Biases entity name. Only applicable if wandb_api_key is set.",
+        default=None,
+    ),
+    wandb_sample_interval: int = Input(
+        description="Step interval for sampling output images that are logged to W&B. Only applicable if wandb_api_key is set.",
+        default=100,
+        ge=1,
+    ),
+    wandb_sample_prompts: str = Input(
+        description="Semicolon-separated list of prompts to use when logging samples to W&B. Only applicable if wandb_api_key is set.",
+        default=None,
+    ),
+    wandb_save_interval: int = Input(
+        description="Step interval for saving intermediate LoRA weights to W&B. Only applicable if wandb_api_key is set.",
+        default=100,
+        ge=1,
+    ),
     skip_training_and_use_pretrained_hf_lora_url: str = Input(
         description="If you’d like to skip LoRA training altogether and instead create a Replicate model from a pre-trained LoRA that’s on HuggingFace, use this field with a HuggingFace download URL. For example, https://huggingface.co/fofr/flux-80s-cyberpunk/resolve/main/lora.safetensors.",
         default=None,
@@ -136,6 +213,37 @@ def train(
         if not input_images:
             raise ValueError("input_images must be provided")
 
+    sample_prompts = []
+    if wandb_sample_prompts:
+        # Drop empty entries left by trailing or doubled semicolons.
+        sample_prompts = [
+            p.strip() for p in wandb_sample_prompts.split(";") if p.strip()
+        ]
+
+    wandb_client = None
+    if wandb_api_key:
+        wandb_config = {
+            "trigger_word": trigger_word,
+            "autocaption": autocaption,
+            "autocaption_prefix": autocaption_prefix,
+            "autocaption_suffix": autocaption_suffix,
+            "steps": steps,
+            "learning_rate": learning_rate,
+            "batch_size": batch_size,
+            "resolution": resolution,
+            "lora_rank": lora_rank,
+            "caption_dropout_rate": caption_dropout_rate,
+            "optimizer": optimizer,
+        }
+        wandb_client = WeightsAndBiasesClient(
+            api_key=wandb_api_key.get_secret_value(),
+            config=wandb_config,
+            sample_prompts=sample_prompts,
+            project=wandb_project,
+            entity=wandb_entity,
+            name=wandb_run,
+        )
+
     download_weights()
     extract_zip(input_images, INPUT_DIR)
 
@@ -143,7 +251,7 @@ def train(
         {
             "job": "custom_job",
             "config": {
-                "name": "flux_train_replicate",
+                "name": JOB_NAME,
                 "process": [
                     {
                         "type": "custom_sd_trainer",
@@ -157,7 +265,9 @@ def train(
                         },
                         "save": {
                             "dtype": "float16",
-                            "save_every": steps + 1,
+                            "save_every": wandb_save_interval
+                            if wandb_api_key
+                            else steps + 1,
                             "max_step_saves_to_keep": 1,
                         },
                         "datasets": [
@@ -166,6 +276,7 @@ def train(
                                 "caption_ext": "txt",
                                 "caption_dropout_rate": caption_dropout_rate,
                                 "shuffle_tokens": False,
+                                # TODO: Do we need to cache to disk? It's faster not to.
                                 "cache_latents_to_disk": True,
                                 "resolution": [
                                     int(res) for res in resolution.split(",")
@@ -193,15 +304,17 @@ def train(
                         },
                         "sample": {
                             "sampler": "flowmatch",
-                            "sample_every": steps + 1,
+                            "sample_every": wandb_sample_interval
+                            if wandb_api_key and sample_prompts
+                            else steps + 1,
                             "width": 1024,
                             "height": 1024,
-                            "prompts": [],
+                            "prompts": sample_prompts,
                             "neg": "",
                             "seed": 42,
                             "walk_seed": True,
-                            "guidance_scale": 4,
-                            "sample_steps": 20,
+                            "guidance_scale": 3.5,
+                            "sample_steps": 28,
                         },
                     }
                 ],
@@ -222,39 +335,54 @@ def train(
         torch.cuda.empty_cache()
 
     print("Starting train job")
-    job = CustomJob(get_config(train_config, name=None))
-    job.run()
+    job = CustomJob(get_config(train_config, name=None), wandb_client)
+    try:
+        job.run()
+    finally:
+        # Close the W&B run even when training raises.
+        if wandb_client:
+            wandb_client.finish()
+
     job.cleanup()
 
-    lora_dir = OUTPUT_DIR / "flux_train_replicate"
-    lora_file = lora_dir / "flux_train_replicate.safetensors"
-    lora_file.rename(lora_dir / "lora.safetensors")
+    lora_file = JOB_DIR / f"{JOB_NAME}.safetensors"
+    lora_file.rename(JOB_DIR / "lora.safetensors")
+
+    samples_dir = JOB_DIR / "samples"
+    if samples_dir.exists():
+        shutil.rmtree(samples_dir)
+
+    # Remove any intermediate lora paths
+    lora_paths = JOB_DIR.glob("*.safetensors")
+    for path in lora_paths:
+        if path.name != "lora.safetensors":
+            path.unlink()
 
     # Optimizer is used to continue training, not needed in output
-    optimizer_file = lora_dir / "optimizer.pt"
+    optimizer_file = JOB_DIR / "optimizer.pt"
     if optimizer_file.exists():
         optimizer_file.unlink()
 
     # Copy generated captions to the output tar
     # But do not upload publicly to HF
-    captions_dir = lora_dir / "captions"
+    captions_dir = JOB_DIR / "captions"
     captions_dir.mkdir(exist_ok=True)
     for caption_file in INPUT_DIR.glob("*.txt"):
         shutil.copy(caption_file, captions_dir)
 
-    os.system(f"tar -cvf {output_path} {lora_dir}")
+    os.system(f"tar -cvf {output_path} {JOB_DIR}")
 
     if hf_token is not None and hf_repo_id is not None:
         if captions_dir.exists():
             shutil.rmtree(captions_dir)
 
         try:
-            handle_hf_readme(lora_dir, hf_repo_id, trigger_word)
+            handle_hf_readme(hf_repo_id, trigger_word)
             print(f"Uploading to Hugging Face: {hf_repo_id}")
             api = HfApi()
             api.upload_folder(
                 repo_id=hf_repo_id,
-                folder_path=lora_dir,
+                folder_path=JOB_DIR,
                 repo_type="model",
                 use_auth_token=hf_token.get_secret_value(),
             )
@@ -264,8 +392,8 @@ def train(
     return TrainingOutput(weights=Path(output_path))
 
 
-def handle_hf_readme(lora_dir: Path, hf_repo_id: str, trigger_word: Optional[str]):
-    readme_path = lora_dir / "README.md"
+def handle_hf_readme(hf_repo_id: str, trigger_word: Optional[str]):
+    readme_path = JOB_DIR / "README.md"
     license_path = Path("lora-license.md")
     shutil.copy(license_path, readme_path)
 
diff --git a/wandb_client.py b/wandb_client.py
new file mode 100644
index 0000000..36ad087
--- /dev/null
+++ b/wandb_client.py
@@ -0,0 +1,55 @@
+from pathlib import Path
+from typing import Any, Sequence
+
+import wandb
+from wandb.sdk.wandb_settings import Settings
+
+
+class WeightsAndBiasesClient:
+    """Thin wrapper around the wandb module for a single training run."""
+
+    def __init__(
+        self,
+        api_key: str,
+        project: str,
+        config: dict,
+        sample_prompts: list[str],
+        entity: str | None,
+        name: str | None,
+    ):
+        """Log in to W&B with ``api_key`` and start a run with ``config``."""
+        self.api_key = api_key
+        self.sample_prompts = sample_prompts
+        wandb.login(key=self.api_key, verify=True)
+        self.run = wandb.init(
+            project=project,
+            entity=entity,
+            name=name,
+            config=config,
+            save_code=False,
+            settings=Settings(_disable_machine_info=True),
+        )
+
+    def log_loss(self, loss_dict: dict[str, Any], step: int | None):
+        """Log the values in ``loss_dict`` for ``step``."""
+        wandb.log(data=loss_dict, step=step)
+
+    def log_samples(self, image_paths: Sequence[Path], step: int | None):
+        """Log sample images, keyed by the prompt paired with each path.
+
+        Prompts and paths are paired positionally; extras on either side
+        are ignored.
+        """
+        data = {
+            f"samples/{prompt}": wandb.Image(str(path))
+            for prompt, path in zip(self.sample_prompts, image_paths)
+        }
+        wandb.log(data=data, step=step)
+
+    def save_weights(self, lora_path: Path):
+        """Register the LoRA weights file with the current run."""
+        wandb.save(lora_path)
+
+    def finish(self):
+        """Finish the current W&B run."""
+        wandb.finish()