Skip to content

Commit

Permalink
cached plugin enum (#768)
Browse files Browse the repository at this point in the history
* initial plugin cache

Signed-off-by: Jeffrey Martin <[email protected]>

* plugin cache as class object

Signed-off-by: Jeffrey Martin <[email protected]>

* plugin cache retrieves only attributes in the base plugin type

* enhance singleton class object access
* ensure sorted enumeration results
* add cache tests for existing function

* tests for plugin cache

Signed-off-by: Jeffrey Martin <[email protected]>

* sort plugin classes in cache for consistent rebuild order

Signed-off-by: Jeffrey Martin <[email protected]>

* ensure description in metadata and skip `post_buff_hook`

* initialize metadata description as doc string for class
* suppress `post_buff_hook`, may rename in future
* test priority fields against key instead of value

Signed-off-by: Jeffrey Martin <[email protected]>

* ensure all doc strings conform to PEP-257

update class doc strings to [PEP-257 multi-line format](https://peps.python.org/pep-0257/#multi-line-docstrings)

Signed-off-by: Jeffrey Martin <[email protected]>

* add initial packaged cache file

Signed-off-by: Jeffrey Martin <[email protected]>

---------

Signed-off-by: Jeffrey Martin <[email protected]>
  • Loading branch information
jmartin-tech authored Jul 16, 2024
1 parent ef92f12 commit 326d8ba
Show file tree
Hide file tree
Showing 15 changed files with 5,092 additions and 89 deletions.
243 changes: 207 additions & 36 deletions garak/_plugins.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,24 +5,215 @@

import importlib
import inspect
import json
import logging
import shutil
import os
from typing import List
from typing import List, Callable, Union
from pathlib import Path

from garak import _config
from garak.exception import GarakException

PLUGIN_TYPES = ("probes", "detectors", "generators", "harnesses", "buffs")
PLUGIN_CLASSES = ("Probe", "Detector", "Generator", "Harness", "Buff")
TIME_FORMAT = "%Y-%m-%d %H:%M:%S %z"


@staticmethod
def _extract_modules_klasses(base_klass):
return [ # Extract only classes with same source package name
name
for name, klass in inspect.getmembers(base_klass, inspect.isclass)
if klass.__module__.startswith(base_klass.__name__)
]
class PluginEncoder(json.JSONEncoder):
    """JSON encoder used when writing the plugin cache file.

    Adds handling for value types the stock encoder rejects:

    * ``set`` is emitted as a sorted list (the members must be mutually
      orderable).
    * ``Path`` is emitted as a string with the ``_config.transient.basedir``
      prefix removed.
    * anything else that cannot be serialised is logged at debug level and
      emitted as ``null`` rather than aborting the dump.
    """

    def default(self, obj):
        """Return a JSON-serialisable stand-in for ``obj``."""
        if isinstance(obj, set):
            # ``list.sort()`` sorts in place and returns None, which would
            # serialise every set as ``null``; ``sorted`` returns the list.
            return sorted(obj)
        if isinstance(obj, Path):
            # relative path for now, may be better to suppress `Path` objects
            return str(obj).replace(str(_config.transient.basedir), "")
        try:
            return json.JSONEncoder.default(self, obj)
        except TypeError as e:
            logging.debug("Attempt to serialize JSON skipped: %s", e)
            return None  # skip items that cannot be serialized at this time


class PluginCache:
_plugin_cache_file = _config.transient.basedir / "resources" / "plugin_cache.json"
_user_plugin_cache_file = _plugin_cache_file
_plugin_cache_dict = None

def __init__(self) -> None:
if PluginCache._plugin_cache_dict is None:
PluginCache._plugin_cache_dict = self._load_plugin_cache()

@staticmethod
def _extract_modules_klasses(base_klass):
return [ # Extract only classes with same source package name
name
for name, klass in inspect.getmembers(base_klass, inspect.isclass)
if klass.__module__.startswith(base_klass.__name__)
]

def _load_plugin_cache(self):
if not os.path.exists(self._plugin_cache_file):
self._build_plugin_cache()
if not os.path.exists(self._user_plugin_cache_file):
shutil.copy2(self._plugin_cache_file, self._user_plugin_cache_file)
with open(self._user_plugin_cache_file, "r", encoding="utf-8") as cache_file:
local_cache = json.load(cache_file)
return local_cache

def _build_plugin_cache(self):
"""build a plugin cache file to improve access times
This method writes only to the user's cache (currently the same as the system cache)
TODO: Enhance location of user cache to enable support for in development plugins.
"""
local_cache = {}

for plugin_type in PLUGIN_TYPES:
plugin_dict = {}
for plugin in self._enumerate_plugin_klasses(plugin_type):
plugin_name = ".".join([plugin.__module__, plugin.__name__]).replace(
"garak.", ""
)
plugin_dict[plugin_name] = PluginCache.plugin_info(plugin)

sorted_keys = sorted(list(plugin_dict.keys()))
local_cache[plugin_type] = {i: plugin_dict[i] for i in sorted_keys}

with open(self._user_plugin_cache_file, "w", encoding="utf-8") as cache_file:
json.dump(local_cache, cache_file, cls=PluginEncoder, indent=2)

def _enumerate_plugin_klasses(self, category: str) -> List[Callable]:
"""obtain all"""
if category not in PLUGIN_TYPES:
raise ValueError("Not a recognised plugin type:", category)

base_mod = importlib.import_module(f"garak.{category}.base")

base_plugin_classnames = set(self._extract_modules_klasses(base_mod))

module_plugin_names = set()

for module_filename in sorted(os.listdir(_config.transient.basedir / category)):
if not module_filename.endswith(".py"):
continue
if module_filename.startswith("__"):
continue
module_name = module_filename.replace(".py", "")
mod = importlib.import_module(f"garak.{category}.{module_name}")
module_entries = set(self._extract_modules_klasses(mod))

for module_entry in module_entries:
obj = getattr(mod, module_entry)
for interface in base_plugin_classnames:
klass = getattr(base_mod, interface)
if issubclass(obj, klass):
module_plugin_names.add(obj)

return module_plugin_names

def instance() -> dict:
return PluginCache()._plugin_cache_dict

def plugin_info(plugin: Union[Callable, str]) -> dict:
"""retrieves the standard attributes for the plugin type"""
if isinstance(plugin, str):
plugin_name = plugin
category = plugin_name.split(".")[0]

if category not in PLUGIN_TYPES:
raise ValueError(f"Not a recognised plugin type: {category}")

plugin_metadata = PluginCache.instance()[category].get(plugin_name, {})
if len(plugin_metadata) > 0:
return plugin_metadata
else:
# the requested plugin is not cached import the class for eval
parts = plugin.split(".")
match len(parts):
case 3:
try:
module = ".".join(parts[:-1])
klass = parts[-1]
imported_module = importlib.import_module(f"garak.{module}")
plugin = getattr(imported_module, klass)
except (AttributeError, ModuleNotFoundError) as e:
if isinstance(e, AttributeError):
msg = f"Not a recognised plugin from {module}: {klass}"
else:
msg = f"Not a recognised plugin module: {plugin}"
raise ValueError(msg)
case _:
raise ValueError(f"Not a recognised plugin class: {plugin}")
else:
plugin_name = ".".join([plugin.__module__, plugin.__name__]).replace(
"garak.", ""
)
category = plugin_name.split(".")[0]

try:
base_attributes = []
base_mod = importlib.import_module(f"garak.{category}.base")
base_plugin_classes = set(PluginCache._extract_modules_klasses(base_mod))
if plugin.__module__ in base_mod.__name__:
# this is a base class enumerate all
base_attributes = dir(plugin)
else:
for klass in base_plugin_classes:
# filter to the base class actually implemented
if issubclass(plugin, getattr(base_mod, klass)):
base_attributes += PluginCache.plugin_info(
getattr(base_mod, klass)
).keys()

plugin_metadata = {}
priority_fields = ["description"]
skip_fields = [
"prompts",
"triggers",
"post_buff_hook",
]

# description as doc string will be overwritten if provided by the class
desc = plugin.__doc__
if desc is not None:
plugin_metadata["description"] = desc.split("\n")[0]

for v in priority_fields:
if hasattr(plugin, v):
plugin_metadata[v] = getattr(plugin, v)
for v in sorted(dir(plugin)):
if v in priority_fields or v in skip_fields:
continue
value = getattr(plugin, v)
if (
v.startswith("_")
or inspect.ismethod(value)
or inspect.isfunction(value)
or v not in base_attributes
):
continue
plugin_metadata[v] = value

except ValueError as e:
logging.exception(e)
except Exception as e:
logging.error(f"Plugin {plugin_name} not found.")
logging.exception(e)

from datetime import datetime, timezone

# adding last class modification time to cache allows for targeted update in future
current_mod = importlib.import_module(plugin.__module__)
mod_time = datetime.fromtimestamp(
os.path.getmtime(current_mod.__file__), tz=timezone.utc
)
plugin_metadata["mod_time"] = mod_time.strftime(TIME_FORMAT)

return plugin_metadata


def plugin_info(plugin: Union[Callable, str]) -> dict:
    """Module-level shortcut that delegates to ``PluginCache.plugin_info``.

    :param plugin: passed through unchanged to ``PluginCache.plugin_info``
    :return: whatever ``PluginCache.plugin_info`` returns for ``plugin``
    """
    return PluginCache.plugin_info(plugin)


def enumerate_plugins(
Expand All @@ -49,37 +240,17 @@ def enumerate_plugins(

base_mod = importlib.import_module(f"garak.{category}.base")

base_plugin_classnames = set(_extract_modules_klasses(base_mod))
base_plugin_classnames = set(PluginCache._extract_modules_klasses(base_mod))

plugin_class_names = []
plugin_class_names = set()

for module_filename in sorted(os.listdir(_config.transient.basedir / category)):
if not module_filename.endswith(".py"):
for k, v in PluginCache.instance()[category].items():
if skip_base_classes and k.split(".")[-1] in base_plugin_classnames:
continue
if module_filename.startswith("__"):
continue
if module_filename == "base.py" and skip_base_classes:
continue
module_name = module_filename.replace(".py", "")
mod = importlib.import_module(
f"garak.{category}.{module_name}"
) # import here will access all namespace level imports consider a cache to speed up processing
module_entries = set(_extract_modules_klasses(mod))
if skip_base_classes:
module_entries = module_entries.difference(base_plugin_classnames)
module_plugin_names = set()
for module_entry in module_entries:
obj = getattr(mod, module_entry)
for interface in base_plugin_classnames:
klass = getattr(base_mod, interface)
if issubclass(obj, klass):
module_plugin_names.add((module_entry, obj.active))

for module_plugin_name, active in sorted(module_plugin_names):
plugin_class_names.append(
(f"{category}.{module_name}.{module_plugin_name}", active)
)
return plugin_class_names
enum_entry = (k, v["active"])
plugin_class_names.add(enum_entry)

return sorted(plugin_class_names)


def load_plugin(path, break_on_fail=True, config_root=_config) -> object:
Expand Down
6 changes: 2 additions & 4 deletions garak/attempt.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,7 @@


class Attempt:
"""A class defining objects that represent everything that constitutes
a single attempt at evaluating an LLM.
"""A class defining objects that represent everything that constitutes a single attempt at evaluating an LLM.
:param status: The status of this attempt; ``ATTEMPT_NEW``, ``ATTEMPT_STARTED``, or ``ATTEMPT_COMPLETE``
:type status: int
Expand Down Expand Up @@ -171,8 +170,7 @@ def __getattribute__(self, name: str) -> Any:
return super().__getattribute__(name)

def __setattr__(self, name: str, value: Any) -> None:
"""override prompt and outputs access to take from history
NB. output elements need to be able to be None"""
"""override prompt and outputs access to take from history NB. output elements need to be able to be None"""

if name == "prompt":
if value is None:
Expand Down
30 changes: 10 additions & 20 deletions garak/command.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,30 +176,20 @@ def print_buffs():

# describe plugin
def plugin_info(plugin_name):
import inspect
from garak._plugins import plugin_info

from garak._plugins import load_plugin

# load plugin
try:
plugin = load_plugin(plugin_name)
info = plugin_info(plugin_name)
if len(info) > 0:
print(f"Configured info on {plugin_name}:")
priority_fields = ["description"]
skip_fields = ["prompts", "triggers"]
# print the attribs it has
for v in priority_fields:
print(f"{v:>35}:", getattr(plugin, v))
for v in sorted(dir(plugin)):
if v in priority_fields or v in skip_fields:
continue
if v.startswith("_") or inspect.ismethod(getattr(plugin, v)):
for k in priority_fields:
if k in info:
print(f"{k:>35}:", info[k])
for k, v in info.items():
if k in priority_fields:
continue
print(f"{v:>35}:", getattr(plugin, v))

except ValueError as e:
print(e)
except Exception as e:
print(e)
print(f"{k:>35}:", v)
else:
print(
f"Plugin {plugin_name} not found. Try --list_probes, or --list_detectors."
)
Expand Down
12 changes: 6 additions & 6 deletions garak/detectors/misleading.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@


class MustRefuteClaimModel(HFDetector):
"""
This detector is looking for the model to refute a claim. A "hit" occurs when the generation
fails to include a refutation.
"""This detector is looking for the model to refute a claim.
A "hit" occurs when the generation fails to include a refutation.
"""

description = (
Expand Down Expand Up @@ -48,9 +48,9 @@ def detect(self, attempt: Attempt) -> List[float]:


class MustContradictNLI(HFDetector):
"""
This detector is looking for the model to contradict a claim. A "hit" occurs when the generation
fails to include a contradiction.
"""This detector is looking for the model to contradict a claim.
A "hit" occurs when the generation fails to include a contradiction.
Using ynie/roberta-large-snli_mnli_fever_anli_R1_R2_R3-nli to detect model refutations
"""
Expand Down
4 changes: 3 additions & 1 deletion garak/detectors/mitigation.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@


class MitigationBypass(StringDetector):
"""Use where the model is expected to put out a mitigation message
"""Detect if a string response does not contain a mitigation message.
Use where the model is expected to put out a mitigation message
instead of complying. Returns a success if no mitigation message is
detected. Runs using a set of keywords found in OpenAI ChatGPT
mitigation messages listed in the ShareGPT dataset used in training
Expand Down
3 changes: 1 addition & 2 deletions garak/detectors/packagehallucination.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,7 @@


class PythonPypi(Detector):
"""Check if the output tries to import a package not listed
in stdlib or a pypi archive listing"""
"""Check if the output tries to import a package not listed in stdlib or a pypi archive listing"""

pypi_packages = None
pypi_dataset_name = "strombergnlp/pypi-20230724"
Expand Down
1 change: 1 addition & 0 deletions garak/generators/huggingface.py
Original file line number Diff line number Diff line change
Expand Up @@ -457,6 +457,7 @@ def _pre_generate_hook(self):

class InferenceEndpoint(InferenceAPI):
"""Interface for Hugging Face private endpoints
Pass the model URL as the name, e.g. https://xxx.aws.endpoints.huggingface.cloud
"""

Expand Down
4 changes: 1 addition & 3 deletions garak/generators/litellm.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,9 +77,7 @@


class LiteLLMGenerator(Generator):
"""Generator wrapper using LiteLLM to allow access to different
providers using the OpenAI API format.
"""
"""Generator wrapper using LiteLLM to allow access to different providers using the OpenAI API format."""

ENV_VAR = "OPENAI_API_KEY"
DEFAULT_PARAMS = Generator.DEFAULT_PARAMS | {
Expand Down
Loading

0 comments on commit 326d8ba

Please sign in to comment.