Decent settings for Booru tags #12

Open
aimerib opened this issue Dec 26, 2024 · 2 comments

@aimerib

aimerib commented Dec 26, 2024

I would have posted this as a discussion topic rather than an issue, since I don't actually have an issue. I'm simply sharing some findings here:

As your README points out, Booru tags are currently quite unstable and tend to repeat a lot. While I agree with your overall assessment that this is a training problem, I've found that I could get somewhat decent results (not great, just slightly more consistent) with the following settings:
[Screenshot (2024-12-26): generation settings used]

(For anyone reading this: CFG = 1 is the same as having CFG off, so right now that slider isn't doing anything. I added it just for testing.)
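Some context on why CFG = 1 is a no-op: classifier-free guidance mixes the conditional and unconditional predictions as `uncond + scale * (cond - uncond)`. As far as I know, transformers does this mixing on log-softmaxed scores and only adds its CFG processor when `guidance_scale != 1`. The minimal pure-Python sketch below (the helper names `log_softmax` and `cfg_mix` are hypothetical, not library APIs) shows that at a scale of 1 the unconditional term cancels out entirely:

```python
import math


def log_softmax(logits: list[float]) -> list[float]:
    """Numerically stable log-softmax over a plain list of logits."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]


def cfg_mix(cond: list[float], uncond: list[float], scale: float) -> list[float]:
    """Classifier-free guidance on log-probs: uncond + scale * (cond - uncond)."""
    c = log_softmax(cond)
    u = log_softmax(uncond)
    return [ui + scale * (ci - ui) for ci, ui in zip(c, u)]


cond = [2.0, 1.0, 0.5]    # toy logits with the prompt/image
uncond = [0.5, 1.5, 1.0]  # toy logits without conditioning

# At scale 1 the result equals the conditional log-probs, so CFG does nothing.
same = all(
    math.isclose(a, b, abs_tol=1e-12)
    for a, b in zip(cfg_mix(cond, uncond, 1.0), log_softmax(cond))
)
print(same)  # True
```

Scales above 1 push the distribution further away from the unconditional prediction, which is where CFG would actually start changing the output.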

Again, this is not a magic cure, but it helped a lot with repetition, to the point where I'd feel comfortable integrating the current model into my captioning workflow today, since it saves time looking up tags and cross-referencing them.
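Repetition penalty is probably what helps most with the repeated tags. Below is a minimal pure-Python sketch of the CTRL-style penalty, which, as far as I know, is what transformers' `repetition_penalty` generation option implements: every token that has already appeared gets its logit pushed down. The function name `apply_repetition_penalty` is hypothetical. Note that in transformers the "already appeared" set is taken from the whole input sequence, so tokens in the prompt are penalized too.

```python
def apply_repetition_penalty(
    logits: list[float], previous_ids: list[int], penalty: float
) -> list[float]:
    """CTRL-style repetition penalty over a plain list of next-token logits."""
    out = list(logits)
    for tok in set(previous_ids):  # each seen token is penalized once, however often it appeared
        # With penalty > 1, multiplying a negative logit or dividing a
        # non-negative one both lower that token's score.
        out[tok] = out[tok] * penalty if out[tok] < 0 else out[tok] / penalty
    return out


logits = [2.4, -1.0, 0.6, 1.2]
penalized = apply_repetition_penalty(logits, previous_ids=[0, 1, 1], penalty=1.2)
print([round(x, 2) for x in penalized])  # [2.0, -1.2, 0.6, 1.2]
```

Only tokens 0 and 1 (already generated) are affected; a penalty of 1.0 leaves the logits unchanged.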

If anyone is curious, I modified the Gradio Space slightly to support these settings:

import spaces
import gradio as gr
from transformers import (
    AutoTokenizer,
    PreTrainedTokenizer,
    PreTrainedTokenizerFast,
    LlavaForConditionalGeneration,
    TextIteratorStreamer,
    GenerationConfig
)
import torch
import torch.amp.autocast_mode
from PIL import Image
import torchvision.transforms.functional as TVF
from threading import Thread
from typing import Generator


MODEL_PATH = "fancyfeast/llama-joycaption-alpha-two-vqa-test-1"
TITLE = "<h1><center>JoyCaption Alpha Two - VQA Test - (2024-11-25a)</center></h1>"
DESCRIPTION = """
...
"""

PLACEHOLDER = """
"""


# Load model
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, use_fast=True)
assert isinstance(tokenizer, PreTrainedTokenizer) or isinstance(
    tokenizer, PreTrainedTokenizerFast
), f"Expected PreTrainedTokenizer, got {type(tokenizer)}"

model = LlavaForConditionalGeneration.from_pretrained(
    MODEL_PATH, torch_dtype="bfloat16", device_map=0
)
assert isinstance(
    model, LlavaForConditionalGeneration
), f"Expected LlavaForConditionalGeneration, got {type(model)}"


def trim_off_prompt(input_ids: list[int], eoh_id: int, eot_id: int) -> list[int]:
    # Trim off the prompt
    while True:
        try:
            i = input_ids.index(eoh_id)
        except ValueError:
            break

        input_ids = input_ids[i + 1 :]

    # Trim off the end
    try:
        i = input_ids.index(eot_id)
    except ValueError:
        return input_ids

    return input_ids[:i]


end_of_header_id = tokenizer.convert_tokens_to_ids("<|end_header_id|>")
end_of_turn_id = tokenizer.convert_tokens_to_ids("<|eot_id|>")
assert isinstance(end_of_header_id, int) and isinstance(end_of_turn_id, int)


@spaces.GPU()
@torch.no_grad()
def chat_joycaption(
    message: dict,
    history,
    temperature: float,
    top_p: float,
    top_k: float,
    min_p: float,
    repeat_penalty: float,
    cfg_scale: float,
    max_new_tokens: int,
) -> Generator[str, None, None]:
    torch.cuda.empty_cache()

    # Prompts are always stripped in training for now
    prompt = message["text"].strip()

    # Load image
    if "files" not in message or len(message["files"]) != 1:
        yield "ERROR: This model requires exactly one image as input."
        return

    image = Image.open(message["files"][0])

    # Preprocess image
    # NOTE: I found the default processor for so400M to have worse results than just using PIL directly
    if image.size != (384, 384):
        image = image.resize((384, 384), Image.LANCZOS)
    image = image.convert("RGB")
    pixel_values = TVF.pil_to_tensor(image)

    convo = [
        {
            "role": "system",
            "content": "You are a helpful image captioner.",
        },
        {
            "role": "user",
            "content": prompt,
        },
    ]

    # Format the conversation
    convo_string = tokenizer.apply_chat_template(
        convo, tokenize=False, add_generation_prompt=True
    )
    assert isinstance(convo_string, str)

    # Tokenize the conversation
    convo_tokens = tokenizer.encode(
        convo_string, add_special_tokens=False, truncation=False
    )

    # Repeat the image tokens
    input_tokens = []
    for token in convo_tokens:
        if token == model.config.image_token_index:
            input_tokens.extend(
                [model.config.image_token_index] * model.config.image_seq_length
            )
        else:
            input_tokens.append(token)

    input_ids = torch.tensor(input_tokens, dtype=torch.long)
    attention_mask = torch.ones_like(input_ids)

    # Move to GPU
    input_ids = input_ids.unsqueeze(0).to("cuda")
    attention_mask = attention_mask.unsqueeze(0).to("cuda")
    pixel_values = pixel_values.unsqueeze(0).to("cuda")

    # Normalize the image
    pixel_values = pixel_values / 255.0
    pixel_values = TVF.normalize(pixel_values, [0.5], [0.5])
    pixel_values = pixel_values.to(torch.bfloat16)

    streamer = TextIteratorStreamer(
        tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True
    )

    generate_kwargs = dict(
        input_ids=input_ids,
        pixel_values=pixel_values,
        attention_mask=attention_mask,
        suppress_tokens=None,
        use_cache=True,
        streamer=streamer,
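        generation_config=GenerationConfig(
            temperature=temperature,
            top_k=int(top_k),  # must be an int; 0 disables top-k
            min_p=min_p,
            # The GenerationConfig field is `repetition_penalty`; an unknown
            # `repeat_penalty` kwarg would be stored but ignored by generate()
            repetition_penalty=repeat_penalty,
            top_p=top_p,
            guidance_scale=cfg_scale,  # 1.0 disables CFG
            max_new_tokens=int(max_new_tokens),
            # Temperature 0 means greedy decoding (GenerationConfig is not
            # subscriptable, so set do_sample here instead of patching it later)
            do_sample=temperature > 0,
        ),
    )

    # Run generation in a background thread and stream tokens back to the UI
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    outputs = []
    for text in streamer:
        outputs.append(text)
        yield "".join(outputs)

    t.join()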


chatbot = gr.Chatbot(
    height=450, placeholder=PLACEHOLDER, label="Gradio ChatInterface", type="messages"
)
textbox = gr.MultimodalTextbox(file_types=["image"], file_count="single")

with gr.Blocks() as demo:
    gr.HTML(TITLE)
    chat_interface = gr.ChatInterface(
        fn=chat_joycaption,
        chatbot=chatbot,
        type="messages",
        fill_height=True,
        multimodal=True,
        textbox=textbox,
        additional_inputs_accordion=gr.Accordion(
            label="⚙️ Parameters", open=True, render=False
        ),
        additional_inputs=[
            gr.Slider(
                minimum=0,
                maximum=1,
                step=0.1,
                value=0.6,
                label="Temperature",
                render=False,
            ),
            gr.Slider(
                minimum=0, maximum=1, step=0.05, value=0.9, label="Top p", render=False
            ),
            gr.Slider(minimum=0, maximum=200, step=1, value=50, label="Top k", render=False),
            # min_p is a fraction of the top token's probability, so it must lie in [0, 1]
            gr.Slider(minimum=0, maximum=1, step=0.01, value=0.1, label="Min p", render=False),
            gr.Slider(
                minimum=0, maximum=4, value=1.2, label="Repeat penalty", render=False
            ),
            gr.Slider(minimum=1, maximum=50, step=0.5, value=1, label="CFG", render=False),
            gr.Slider(
                minimum=8,
                maximum=4096,
                step=1,
                value=1024,
                label="Max new tokens",
                render=False,
            ),
        ],
    )

    gr.Markdown(DESCRIPTION)


if __name__ == "__main__":
    demo.launch()
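
To make the sliders above a bit less opaque, here is a rough pure-Python sketch of how temperature, top-k, top-p and min-p narrow the next-token distribution before a token is sampled. The function name `filter_distribution` is hypothetical, the filters run in roughly (not necessarily exactly) the order transformers applies its logits warpers, and it assumes temperature > 0, since the Space is meant to switch to greedy decoding at 0.

```python
import math


def softmax(logits: list[float]) -> list[float]:
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def filter_distribution(
    logits: list[float], temperature: float, top_k: int, top_p: float, min_p: float
) -> dict[int, float]:
    """Return the renormalized {token_id: prob} left after all four filters."""
    if temperature <= 0:
        raise ValueError("temperature must be > 0 when sampling")
    probs = softmax([x / temperature for x in logits])
    ranked = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)

    # Top-k: keep only the k most likely tokens (0 disables it).
    if top_k > 0:
        ranked = ranked[:top_k]

    # Top-p: keep the smallest prefix whose cumulative probability reaches top_p.
    kept, cumulative = [], 0.0
    for i in ranked:
        kept.append(i)
        cumulative += probs[i]
        if cumulative >= top_p:
            break

    # Min-p: drop tokens less likely than min_p times the most likely token.
    threshold = min_p * probs[kept[0]]
    kept = [i for i in kept if probs[i] >= threshold]

    total = sum(probs[i] for i in kept)
    return {i: probs[i] / total for i in kept}


# The Space's slider defaults applied to a toy 5-token distribution.
dist = filter_distribution([3.0, 2.0, 1.0, 0.0, -1.0], 0.6, 50, 0.9, 0.1)
print(sorted(dist))  # [0, 1]
```

With those defaults only the two most likely toy tokens survive, which illustrates how these settings trade variety for consistency.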
@fpgaminer
Owner

Thank you for sharing! Yeah, repetition penalty is likely to help a lot with the issue. That said, if you want tags for an image, WD Tagger and JoyTag will be more accurate and consistent. For now, the main goal of baking a booru tag mode into JoyCaption is mostly to teach it about those tags.

@aimerib
Author

aimerib commented Jan 2, 2025

100% agreed. I'm looking for something that can do both tags and natural language. My ultimate goal is to compile a library of really solid prompts across SDXL model families (SDXL base finetunes, Pony finetunes, and Illustrious finetunes) and eventually train a model to help with zero-shot prompting for t2i. I've been scouring Civitai for metadata, but I've found that a lot of the prompts work either by miracle or by LoRA. I quickly realized that the fastest way to build a good dataset would be to take the visual elements of an image and turn them into a prompt, and JoyCaption is excellent for this task. The next step is to run JoyCaption's output through something like TIPO (https://huggingface.co/spaces/KBlueLeaf/TIPO-DEMO; there's no inference code yet, but the model used in their demos is Llama-based, so I should be able to figure something out) for prompt expansion, and then refine from there (although TIPO is already really good at extending the prompt).
This workflow really wouldn't be possible without something like JoyCaption! Other models tend to lack creativity.
