
Commit 40ef93f
use optimum from branch
eaidova committed Feb 4, 2025
1 parent 31507cc commit 40ef93f
Showing 2 changed files with 47 additions and 61 deletions.
106 changes: 46 additions & 60 deletions optimum/exporters/openvino/model_configs.py
@@ -145,7 +145,7 @@ def init_model_configs():
# for model registration in auto transformers classses
if importlib.util.find_spec("janus") is not None:
try:
from janus.models import MultiModalityCausalLM, VLChatProcessor
from janus.models import MultiModalityCausalLM # noqa: F401
except ImportError:
pass

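The hunk above narrows the optional janus import to the one name whose import side effect matters: importing MultiModalityCausalLM registers the Janus classes with transformers' auto classes (per the comment in the surrounding context), and the `# noqa: F401` marker tells the linter the apparently unused import is intentional. A minimal sketch of the guarded-import pattern, with the context lines assumed from the diff:

    import importlib.util

    # Attempt the import only when the optional dependency is installed,
    # and tolerate broken installs; the import runs purely for its
    # class-registration side effect, hence the noqa: F401 marker.
    if importlib.util.find_spec("janus") is not None:
        try:
            from janus.models import MultiModalityCausalLM  # noqa: F401
        except ImportError:
            pass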
@@ -1352,9 +1352,7 @@ def patch_model_for_export(


class LMInputEmbedsConfigHelper(TextDecoderWithPositionIdsOnnxConfig):
def __init__(
self, export_config, patcher_cls=None, dummy_input_generator=None, inputs_update=None, remove_lm_head=False
):
def __init__(self, export_config, patcher_cls=None, dummy_input_generator=None, inputs_update=None, remove_lm_head=False):
self.orig_export_config = export_config
if dummy_input_generator is not None:
export_config.DUMMY_INPUT_GENERATOR_CLASSES = (
@@ -1373,15 +1371,16 @@ def __init__(
def patch_model_for_export(
self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
) -> "ModelPatcher":

if self.patcher_cls is not None:
patcher = self.patcher_cls(self, model, model_kwargs=model_kwargs)
# Refer to DecoderModelPatcher.
else:
else:
patcher = self.orig_export_config.patch_model_for_export(model, model_kwargs=model_kwargs)

if self.remove_lm_head:
patcher = RemoveLMHeadPatcherHelper(self, model, model_kwargs, patcher)

return patcher

@property
@@ -1390,7 +1389,7 @@ def outputs(self) -> Dict[str, Dict[int, str]]:
if self.remove_lm_head:
logits_info = outputs.pop("logits")
updated_outputs = {"last_hidden_state": logits_info}
return {**updated_outputs, **outputs}
return {**updated_outputs, **outputs}
return outputs

@property
@@ -1479,15 +1478,15 @@ def get_vlm_text_generation_config(
model_patcher=None,
dummy_input_generator=None,
inputs_update=None,
remove_lm_head=False,
remove_lm_head=False
):
internal_export_config = get_vlm_internal_text_generation_config(model_type, model_config, int_dtype, float_dtype)
export_config = LMInputEmbedsConfigHelper(
internal_export_config,
patcher_cls=model_patcher,
dummy_input_generator=dummy_input_generator,
inputs_update=inputs_update,
remove_lm_head=remove_lm_head,
remove_lm_head=remove_lm_head
)
export_config._normalized_config = internal_export_config._normalized_config
return export_config
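This helper threads the new remove_lm_head flag into LMInputEmbedsConfigHelper: when set, patch_model_for_export wraps the patcher in RemoveLMHeadPatcherHelper and the outputs property renames logits to last_hidden_state, so the exported decoder stops at the hidden states. A hypothetical call, using only the signature shown above (the argument values are illustrative, not from this commit):

    export_config = get_vlm_text_generation_config(
        model_type="multi-modality",   # a type registered further down
        model_config=model_config,     # the multimodal model's config
        int_dtype="int64",
        float_dtype="fp32",
        remove_lm_head=True,           # export the decoder without its LM
    )                                  # head; output is last_hidden_state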
@@ -2812,60 +2811,45 @@ class JanusConfigBehavior(str, enum.Enum):


class JanusDummyVisionGenInputGenerator(DummyInputGenerator):
SUPPORTED_INPUT_NAMES = ("pixel_values", "image_ids", "code_b", "image_shape", "lm_hidden_state", "hidden_state")
SUPPORTED_INPUT_NAMES = (
"pixel_values",
"image_ids",
"code_b",
"image_shape",
"lm_hidden_state",
"hidden_state"
)

def __init__(
self,
task: str,
normalized_config: NormalizedConfig,
batch_size: int = DEFAULT_DUMMY_SHAPES["batch_size"],
sequence_length: int = DEFAULT_DUMMY_SHAPES["sequence_length"],
**kwargs,
):
self.task = task
self.batch_size = batch_size
self.sequence_length = sequence_length
self.normalized_config = normalized_config

self,
task: str,
normalized_config: NormalizedConfig,
batch_size: int = DEFAULT_DUMMY_SHAPES["batch_size"],
sequence_length: int = DEFAULT_DUMMY_SHAPES["sequence_length"],
**kwargs,
):
self.task = task
self.batch_size = batch_size
self.sequence_length = sequence_length
self.normalized_config = normalized_config
def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
if input_name == "pixel_values":
return self.random_float_tensor(
[
self.batch_size,
1,
3,
self.normalized_config.config.params.image_size,
self.normalized_config.config.params.image_size,
]
)

return self.random_float_tensor([self.batch_size, 1, 3, self.normalized_config.config.params.image_size, self.normalized_config.config.params.image_size])

if input_name == "image_ids":
return self.random_int_tensor(
[self.sequence_length],
max_value=self.normalized_config.config.params.image_token_size,
framework=framework,
dtype=int_dtype,
)
return self.random_int_tensor([self.sequence_length], max_value=self.normalized_config.config.params.image_token_size, framework=framework, dtype=int_dtype)
if input_name == "code_b":
return self.random_int_tensor(
[self.batch_size, 576],
max_value=self.normalized_config.config.params.image_token_size,
framework=framework,
dtype=int_dtype,
)
return self.random_int_tensor([self.batch_size, 576], max_value=self.normalized_config.config.params.image_token_size, framework=framework, dtype=int_dtype)
if input_name == "image_shape":
import torch

return torch.tensor(
[self.batch_size, self.normalized_config.config.params.n_embed, 24, 24], dtype=torch.int64
)
return torch.tensor([self.batch_size, self.normalized_config.config.params.n_embed, 24, 24], dtype=torch.int64)
if input_name == "hidden_state":
return self.random_float_tensor(
[self.batch_size, self.sequence_length, self.normalized_config.hidden_size]
)
return self.random_float_tensor([self.batch_size, self.sequence_length, self.normalized_config.hidden_size])
if input_name == "lm_hidden_state":
return self.random_float_tensor([self.sequence_length, self.normalized_config.hidden_size])
return super().generate(input_name, framework, int_dtype, float_dtype)



@register_in_tasks_manager("multi-modality", *["image-text-to-text", "any-to-any"], library_name="transformers")
@@ -2883,7 +2867,7 @@ def __init__(
float_dtype: str = "fp32",
behavior: JanusConfigBehavior = JanusConfigBehavior.VISION_EMBEDDINGS,
preprocessors: Optional[List[Any]] = None,
**kwargs,
**kwargs
):
super().__init__(
config=config,
@@ -2897,9 +2881,7 @@ def __init__(
if self._behavior == JanusConfigBehavior.VISION_EMBEDDINGS and hasattr(config, "vision_config"):
self._config = config.vision_config
self._normalized_config = NormalizedVisionConfig(self._config)
if self._behavior in [JanusConfigBehavior.LM_HEAD, JanusConfigBehavior.VISION_GEN_HEAD] and hasattr(
config, "language_config"
):
if self._behavior in [JanusConfigBehavior.LM_HEAD, JanusConfigBehavior.VISION_GEN_HEAD] and hasattr(config, "language_config"):
self._config = config.language_config
self._normalized_config = NormalizedTextConfig(self._config)
if self._behavior == JanusConfigBehavior.VISION_GEN_EMBEDDINGS and hasattr(config, "gen_head_config"):
@@ -2929,7 +2911,7 @@ def outputs(self) -> Dict[str, Dict[int, str]]:
return {"last_hidden_state": {0: "batch_size"}}
if self._behavior == JanusConfigBehavior.VISION_GEN_EMBEDDINGS:
return {"last_hidden_state": {0: "num_tokens"}}

if self._behavior == JanusConfigBehavior.LM_HEAD:
return {"logits": {0: "batch_size", 1: "sequence_length"}}

@@ -2996,6 +2978,7 @@ def with_behavior(
preprocessors=self._preprocessors,
)


if behavior == JanusConfigBehavior.VISION_EMBEDDINGS:
return self.__class__(
self._orig_config,
@@ -3005,7 +2988,7 @@ def with_behavior(
behavior=behavior,
preprocessors=self._preprocessors,
)

if behavior == JanusConfigBehavior.VISION_GEN_DECODER:
return self.__class__(
self._orig_config,
@@ -3016,6 +2999,7 @@ def with_behavior(
preprocessors=self._preprocessors,
)


def get_model_for_behavior(self, model, behavior: Union[str, JanusConfigBehavior]):
if isinstance(behavior, str) and not isinstance(behavior, JanusConfigBehavior):
behavior = JanusConfigBehavior(behavior)
@@ -3038,7 +3022,7 @@ def get_model_for_behavior(self, model, behavior: Union[str, JanusConfigBehavior

if behavior == JanusConfigBehavior.VISION_GEN_EMBEDDINGS:
return model

if behavior == JanusConfigBehavior.VISION_GEN_HEAD:
gen_head = model.gen_head
gen_head.config = model.language_model.config
@@ -3047,6 +3031,7 @@ def get_model_for_behavior(self, model, behavior: Union[str, JanusConfigBehavior
if behavior == JanusConfigBehavior.VISION_GEN_DECODER:
return model.gen_vision_model


def patch_model_for_export(
self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
):
@@ -3059,6 +3044,7 @@ def patch_model_for_export(
return JanusVisionGenDecoderModelPatcher(self, model, model_kwargs)
return super().patch_model_for_export(model, model_kwargs)


def rename_ambiguous_inputs(self, inputs):
if self._behavior == JanusConfigBehavior.VISION_GEN_HEAD:
data = inputs.pop("lm_hidden_state")
@@ -3069,4 +3055,4 @@ def rename_ambiguous_inputs(self, inputs):
if self._behavior == JanusConfigBehavior.VISION_GEN_DECODER:
data = inputs.pop("image_shape")
inputs["shape"] = data
return inputs
return inputs
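Taken together, the Janus pieces export one multimodal checkpoint as several sub-models keyed by JanusConfigBehavior: with_behavior derives a per-part export config, get_model_for_behavior selects the matching submodule (the language model with or without its head, gen_head, gen_vision_model, and so on), and patch_model_for_export picks the behavior-specific patcher. A rough sketch of the export loop a caller might run, assuming a janus_config instance of the class above and a loaded Janus model (neither is defined in this diff):

    for behavior in JanusConfigBehavior:
        part_config = janus_config.with_behavior(behavior)                 # per-part config
        part_model = janus_config.get_model_for_behavior(model, behavior)  # submodule
        patcher = part_config.patch_model_for_export(part_model)           # export patcher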
2 changes: 1 addition & 1 deletion setup.py
@@ -28,7 +28,7 @@

INSTALL_REQUIRE = [
"torch>=1.11",
"optimum~=1.24",
"optimum @ git+https://github.com/eaidova/optimum.git@ea/avoid_lib_guessing_in_standartize_args",
"transformers>=4.36,<4.48",
"datasets>=1.4.0",
"sentencepiece",

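The setup.py change swaps the released optimum~=1.24 pin for a PEP 508 direct reference, so installing this package pulls optimum straight from the named branch; the dependency resolves the same way as:

    pip install "optimum @ git+https://github.com/eaidova/optimum.git@ea/avoid_lib_guessing_in_standartize_args"

Note that PyPI refuses uploads whose dependencies use direct URL references, so a pin like this works only for source/VCS installs and reads as a stopgap until the branch lands in an optimum release.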