Skip to content

Commit

Permalink
make llava-relevant modules load on instantiation; load/unload them a…
Browse files Browse the repository at this point in the history
…round pickling
  • Loading branch information
leondz committed Jun 12, 2024
1 parent ba28a06 commit 184cca5
Showing 1 changed file with 28 additions and 7 deletions.
35 changes: 28 additions & 7 deletions garak/generators/huggingface.py
Original file line number Diff line number Diff line change
Expand Up @@ -578,20 +578,40 @@ class LLaVA(Generator):
"llava-hf/llava-v1.6-mistral-7b-hf",
]

def __getstate__(self) -> object:
    """Return a copy of the instance state with the client handles cleared.

    ``_clear_client`` is used to blank the cached module/class handles so they
    are not included in the pickled state. The live instance is restored
    afterwards, so pickling (or ``copy``/``deepcopy``, which also go through
    ``__getstate__``) does not leave the original object without its handles.

    Returns:
        A new ``dict`` holding the instance attributes, with the handles that
        ``_clear_client`` resets set to ``None``.
    """
    live_state = dict(self.__dict__)
    self._clear_client()
    try:
        return dict(self.__dict__)
    finally:
        # Undo the clearing on the live object; only the returned snapshot
        # should carry the blanked handles.
        self.__dict__.update(live_state)

def __setstate__(self, d) -> object:
    """Rebuild the instance from unpickled state and reload the client handles.

    Args:
        d: Mapping of attribute names to values, as produced by
            ``__getstate__``.
    """
    # Apply the saved attributes first so ``_load_client`` runs on a fully
    # populated instance.
    vars(self).update(d)
    self._load_client()

def _load_client(self):
    """Import the PIL / transformers dependencies and cache handles on ``self``.

    Sets ``self.Image`` to the ``PIL.Image`` module, and
    ``self.LlavaNextProcessor`` / ``self.LlavaNextForConditionalGeneration``
    to the corresponding ``transformers`` classes.

    Raises:
        ImportError: If ``PIL`` or ``transformers`` cannot be imported.
        AttributeError: If ``transformers`` lacks the LLaVA-NeXT classes.
    """
    # Import the submodule explicitly: importing the ``PIL`` package alone
    # does not bind ``PIL.Image`` unless something else already imported it.
    self.Image = importlib.import_module("PIL.Image")

    transformers = importlib.import_module("transformers")
    self.LlavaNextProcessor = transformers.LlavaNextProcessor
    self.LlavaNextForConditionalGeneration = (
        transformers.LlavaNextForConditionalGeneration
    )

def _clear_client(self):
    """Reset the cached PIL / transformers handles on this instance to ``None``."""
    handle_names = (
        "Image",
        "LlavaNextProcessor",
        "LlavaNextForConditionalGeneration",
    )
    for handle in handle_names:
        setattr(self, handle, None)

def __init__(self, name="", generations=10, config_root=_config):
super().__init__(name, generations=generations, config_root=config_root)
if self.name not in self.supported_models:
raise ModelNameMissingError(
f"Invalid modal name {self.name}, current support: {self.supported_models}."
)

from transformers import LlavaNextProcessor, LlavaNextForConditionalGeneration

PIL = importlib.import_module("PIL")
self.Image = PIL.Image

self.processor = LlavaNextProcessor.from_pretrained(self.name)
self.model = LlavaNextForConditionalGeneration.from_pretrained(
self.processor = self.LlavaNextProcessor.from_pretrained(self.name)
self.model = self.LlavaNextForConditionalGeneration.from_pretrained(
self.name,
torch_dtype=self.torch_dtype,
low_cpu_mem_usage=self.low_cpu_mem_usage,
Expand All @@ -602,6 +622,7 @@ def __init__(self, name="", generations=10, config_root=_config):
raise RuntimeError(
"CUDA is not supported on this device. Please make sure CUDA is installed and configured properly."
)
self._load_client()

def generate(
self, prompt: str, generations_this_call: int = 1
Expand Down

0 comments on commit 184cca5

Please sign in to comment.