Skip to content

Commit

Permalink
Run Ruff
Browse files Browse the repository at this point in the history
  • Loading branch information
brandonrising committed Aug 15, 2024
1 parent b5f35ed commit 027ac63
Show file tree
Hide file tree
Showing 6 changed files with 8 additions and 11 deletions.
5 changes: 1 addition & 4 deletions invokeai/app/invocations/flux_text_to_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,10 +98,7 @@ def _run_diffusion(
# if the cache is not empty.
# context.models._services.model_manager.load.ram_cache.make_room(24 * 2**30)

with (
transformer_info as transformer,
scheduler_info as scheduler
):
with transformer_info as transformer, scheduler_info as scheduler:
assert isinstance(transformer, FluxTransformer2DModel)
assert isinstance(scheduler, FlowMatchEulerDiscreteScheduler)

Expand Down
2 changes: 1 addition & 1 deletion invokeai/app/invocations/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,11 +60,11 @@ class CLIPField(BaseModel):
loras: List[LoRAField] = Field(description="LoRAs to apply on model loading")



class TransformerField(BaseModel):
transformer: ModelIdentifierField = Field(description="Info to load Transformer submodel")
scheduler: ModelIdentifierField = Field(description="Info to load scheduler submodel")


class T5EncoderField(BaseModel):
tokenizer: ModelIdentifierField = Field(description="Info to load tokenizer submodel")
text_encoder: ModelIdentifierField = Field(description="Info to load text_encoder submodel")
Expand Down
4 changes: 1 addition & 3 deletions invokeai/backend/model_manager/load/model_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,7 @@ def calc_model_size_by_data(logger: logging.Logger, model: AnyModel) -> int:
return model.calc_size()
elif isinstance(
model,
(
T5TokenizerFast,
),
(T5TokenizerFast,),
):
return len(model)
else:
Expand Down
2 changes: 1 addition & 1 deletion invokeai/backend/model_manager/util/select_hf_files.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def filter_files(
"lora_weights.safetensors",
"weights.pb",
"onnx_data",
"spiece.model", # Added for `black-forest-labs/FLUX.1-schnell`.
"spiece.model", # Added for `black-forest-labs/FLUX.1-schnell`.
)
):
paths.append(file)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

class FastQuantizedDiffusersModel(QuantizedDiffusersModel):
@classmethod
def from_pretrained(cls, model_name_or_path: Union[str, os.PathLike], base_class = FluxTransformer2DModel, **kwargs):
def from_pretrained(cls, model_name_or_path: Union[str, os.PathLike], base_class=FluxTransformer2DModel, **kwargs):
"""We override the `from_pretrained()` method in order to use our custom `requantize()` implementation."""
base_class = base_class or cls.base_class
if base_class is None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@

class FastQuantizedTransformersModel(QuantizedTransformersModel):
@classmethod
def from_pretrained(cls, model_name_or_path: Union[str, os.PathLike], auto_class = AutoModelForTextEncoding, **kwargs):
def from_pretrained(
cls, model_name_or_path: Union[str, os.PathLike], auto_class=AutoModelForTextEncoding, **kwargs
):
"""We override the `from_pretrained()` method in order to use our custom `requantize()` implementation."""
auto_class = auto_class or cls.auto_class
if auto_class is None:
Expand Down

0 comments on commit 027ac63

Please sign in to comment.