Skip to content

Commit

Permalink
Run Ruff
Browse files Browse the repository at this point in the history
  • Loading branch information
brandonrising committed Aug 15, 2024
1 parent 027ac63 commit 0e14fc6
Show file tree
Hide file tree
Showing 4 changed files with 5 additions and 17 deletions.
10 changes: 2 additions & 8 deletions invokeai/app/invocations/flux_text_encoder.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,9 @@
from pathlib import Path

import torch
from diffusers.pipelines.flux.pipeline_flux import FluxPipeline
from optimum.quanto import qfloat8
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast

from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.model import CLIPField, T5EncoderField
from invokeai.app.invocations.fields import InputField, FieldDescriptions, Input
from invokeai.app.invocations.flux_text_to_image import FLUX_MODELS, QuantizedModelForTextEncoding
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField
from invokeai.app.invocations.model import CLIPField, T5EncoderField
from invokeai.app.invocations.primitives import ConditioningOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
Expand Down Expand Up @@ -40,15 +35,14 @@ class FluxTextEncoderInvocation(BaseInvocation):
# compatible with other ConditioningOutputs.
@torch.no_grad()
def invoke(self, context: InvocationContext) -> ConditioningOutput:

t5_embeddings, clip_embeddings = self._encode_prompt(context)
conditioning_data = ConditioningFieldData(
conditionings=[FLUXConditioningInfo(clip_embeds=clip_embeddings, t5_embeds=t5_embeddings)]
)

conditioning_name = context.conditioning.save(conditioning_data)
return ConditioningOutput.build(conditioning_name)

def _encode_prompt(self, context: InvocationContext) -> tuple[torch.Tensor, torch.Tensor]:
# TODO: Determine the T5 max sequence length based on the model.
# if self.model == "flux-schnell":
Expand Down
7 changes: 1 addition & 6 deletions invokeai/app/invocations/flux_text_to_image.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,9 @@
from pathlib import Path
from typing import Literal
from pydantic import Field

import torch
from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler
from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel
from diffusers.pipelines.flux.pipeline_flux import FluxPipeline
from invokeai.app.invocations.model import TransformerField, VAEField
from optimum.quanto import qfloat8
from PIL import Image
from transformers.models.auto import AutoModelForTextEncoding

Expand All @@ -19,8 +15,8 @@
InputField,
WithBoard,
WithMetadata,
UIType,
)
from invokeai.app.invocations.model import TransformerField, VAEField
from invokeai.app.invocations.primitives import ImageOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.quantization.fast_quantized_diffusion_model import FastQuantizedDiffusersModel
Expand Down Expand Up @@ -72,7 +68,6 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):

@torch.no_grad()
def invoke(self, context: InvocationContext) -> ImageOutput:

# Load the conditioning data.
cond_data = context.conditioning.load(self.positive_text_conditioning.conditioning_name)
assert len(cond_data.conditionings) == 1
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from typing import Union

from diffusers.models.model_loading_utils import load_state_dict
from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel
from diffusers.utils import (
CONFIG_NAME,
SAFE_WEIGHTS_INDEX_NAME,
Expand All @@ -12,7 +13,6 @@
)
from optimum.quanto.models import QuantizedDiffusersModel
from optimum.quanto.models.shared_dict import ShardedStateDict
from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel

from invokeai.backend.requantize import requantize

Expand Down
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
import json
import os
import torch
from typing import Union

from optimum.quanto.models import QuantizedTransformersModel
from optimum.quanto.models.shared_dict import ShardedStateDict
from transformers import AutoConfig
from transformers.modeling_utils import get_checkpoint_shard_files, load_state_dict
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME, is_accelerate_available
from transformers.models.auto import AutoModelForTextEncoding
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME, is_accelerate_available

from invokeai.backend.requantize import requantize

Expand Down

0 comments on commit 0e14fc6

Please sign in to comment.