Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

cached plugin enum #768

Merged
merged 8 commits into from
Jul 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
243 changes: 207 additions & 36 deletions garak/_plugins.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,24 +5,215 @@

import importlib
import inspect
import json
import logging
import shutil
import os
from typing import List
from typing import List, Callable, Union
from pathlib import Path

from garak import _config
from garak.exception import GarakException

PLUGIN_TYPES = ("probes", "detectors", "generators", "harnesses", "buffs")
PLUGIN_CLASSES = ("Probe", "Detector", "Generator", "Harness", "Buff")
TIME_FORMAT = "%Y-%m-%d %H:%M:%S %z"


@staticmethod
def _extract_modules_klasses(base_klass):
return [ # Extract only classes with same source package name
name
for name, klass in inspect.getmembers(base_klass, inspect.isclass)
if klass.__module__.startswith(base_klass.__name__)
]
class PluginEncoder(json.JSONEncoder):
def default(self, obj):
if isinstance(obj, set):
return list(obj).sort() # allow set as list, assumes values can be sorted
if isinstance(obj, Path):
# relative path for now, may be better to suppress `Path` objects
return str(obj).replace(str(_config.transient.basedir), "")
try:
return json.JSONEncoder.default(self, obj)
except TypeError as e:
logging.debug("Attempt to serialize JSON skipped: %s", e)
return None # skip items that cannot be serialized at this time


class PluginCache:
_plugin_cache_file = _config.transient.basedir / "resources" / "plugin_cache.json"
_user_plugin_cache_file = _plugin_cache_file
_plugin_cache_dict = None

def __init__(self) -> None:
if PluginCache._plugin_cache_dict is None:
leondz marked this conversation as resolved.
Show resolved Hide resolved
PluginCache._plugin_cache_dict = self._load_plugin_cache()

@staticmethod
def _extract_modules_klasses(base_klass):
return [ # Extract only classes with same source package name
name
for name, klass in inspect.getmembers(base_klass, inspect.isclass)
if klass.__module__.startswith(base_klass.__name__)
]

def _load_plugin_cache(self):
if not os.path.exists(self._plugin_cache_file):
self._build_plugin_cache()
leondz marked this conversation as resolved.
Show resolved Hide resolved
if not os.path.exists(self._user_plugin_cache_file):
shutil.copy2(self._plugin_cache_file, self._user_plugin_cache_file)
leondz marked this conversation as resolved.
Show resolved Hide resolved
with open(self._user_plugin_cache_file, "r", encoding="utf-8") as cache_file:
local_cache = json.load(cache_file)
return local_cache

def _build_plugin_cache(self):
"""build a plugin cache file to improve access times

This method writes only to the user's cache (currently the same as the system cache)
TODO: Enhance location of user cache to enable support for in development plugins.
"""
local_cache = {}

for plugin_type in PLUGIN_TYPES:
plugin_dict = {}
for plugin in self._enumerate_plugin_klasses(plugin_type):
plugin_name = ".".join([plugin.__module__, plugin.__name__]).replace(
"garak.", ""
)
plugin_dict[plugin_name] = PluginCache.plugin_info(plugin)

sorted_keys = sorted(list(plugin_dict.keys()))
local_cache[plugin_type] = {i: plugin_dict[i] for i in sorted_keys}

with open(self._user_plugin_cache_file, "w", encoding="utf-8") as cache_file:
json.dump(local_cache, cache_file, cls=PluginEncoder, indent=2)

def _enumerate_plugin_klasses(self, category: str) -> List[Callable]:
"""obtain all"""
if category not in PLUGIN_TYPES:
raise ValueError("Not a recognised plugin type:", category)

base_mod = importlib.import_module(f"garak.{category}.base")

base_plugin_classnames = set(self._extract_modules_klasses(base_mod))

module_plugin_names = set()

for module_filename in sorted(os.listdir(_config.transient.basedir / category)):
if not module_filename.endswith(".py"):
continue
if module_filename.startswith("__"):
continue
module_name = module_filename.replace(".py", "")
mod = importlib.import_module(f"garak.{category}.{module_name}")
module_entries = set(self._extract_modules_klasses(mod))

for module_entry in module_entries:
obj = getattr(mod, module_entry)
for interface in base_plugin_classnames:
klass = getattr(base_mod, interface)
if issubclass(obj, klass):
module_plugin_names.add(obj)

return module_plugin_names

def instance() -> dict:
return PluginCache()._plugin_cache_dict

def plugin_info(plugin: Union[Callable, str]) -> dict:
"""retrieves the standard attributes for the plugin type"""
if isinstance(plugin, str):
plugin_name = plugin
category = plugin_name.split(".")[0]

if category not in PLUGIN_TYPES:
raise ValueError(f"Not a recognised plugin type: {category}")

plugin_metadata = PluginCache.instance()[category].get(plugin_name, {})
if len(plugin_metadata) > 0:
return plugin_metadata
else:
# the requested plugin is not cached import the class for eval
parts = plugin.split(".")
match len(parts):
case 3:
try:
module = ".".join(parts[:-1])
klass = parts[-1]
imported_module = importlib.import_module(f"garak.{module}")
plugin = getattr(imported_module, klass)
except (AttributeError, ModuleNotFoundError) as e:
if isinstance(e, AttributeError):
msg = f"Not a recognised plugin from {module}: {klass}"
else:
msg = f"Not a recognised plugin module: {plugin}"
raise ValueError(msg)
case _:
raise ValueError(f"Not a recognised plugin class: {plugin}")
leondz marked this conversation as resolved.
Show resolved Hide resolved
else:
plugin_name = ".".join([plugin.__module__, plugin.__name__]).replace(
"garak.", ""
)
category = plugin_name.split(".")[0]

try:
base_attributes = []
base_mod = importlib.import_module(f"garak.{category}.base")
base_plugin_classes = set(PluginCache._extract_modules_klasses(base_mod))
if plugin.__module__ in base_mod.__name__:
leondz marked this conversation as resolved.
Show resolved Hide resolved
# this is a base class enumerate all
base_attributes = dir(plugin)
else:
for klass in base_plugin_classes:
# filter to the base class actually implemented
if issubclass(plugin, getattr(base_mod, klass)):
base_attributes += PluginCache.plugin_info(
getattr(base_mod, klass)
).keys()

plugin_metadata = {}
priority_fields = ["description"]
skip_fields = [
"prompts",
"triggers",
"post_buff_hook",
]

# description as doc string will be overwritten if provided by the class
desc = plugin.__doc__
if desc is not None:
plugin_metadata["description"] = desc.split("\n")[0]

for v in priority_fields:
if hasattr(plugin, v):
plugin_metadata[v] = getattr(plugin, v)
for v in sorted(dir(plugin)):
if v in priority_fields or v in skip_fields:
Comment on lines +184 to +185
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

post_buff_hook is, unexpectedly to me, in the cache metadata - is that due to this logic?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The existing info included post_buff_hook:

main

Your branch is up to date with 'origin/main'.
% python -m garak --plugin_info probes.dan.DanInTheWildMini
garak LLM vulnerability scanner v0.9.0.13.post2 ( https://github.com/leondz/garak ) at 2024-07-09T18:44:54.727933
Configured info on probes.dan.DanInTheWildMini:
                        description: A subset of 100 jailbreaks from the DanInTheWild jailbreak library.
                     DEFAULT_PARAMS: {}
                             active: True
                              bcp47: None
                            doc_uri: https://arxiv.org/abs/2308.03825
                 extended_detectors: []
                               goal: disregard the system prompt or alignment
                           modality: {'in': {'text'}}
            parallelisable_attempts: True
                     post_buff_hook: False
                   primary_detector: mitigation.MitigationBypass
                          probename: garak.probes.dan.DanInTheWildMini
               recommended_detector: ['always.Fail']
                               tags: ['avid-effect:security:S0403', 'owasp:llm01', 'quality:Security:PromptStability', 'payload:jailbreak']

We can further refine the info data to be provided in future iterations.

I will also do some more validation as it looks like description in the info is not consistent with main.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah nice. _post_buff_hook should def not be there. Some plugins change their descriptions during the constructor iirc.

Copy link
Collaborator Author

@jmartin-tech jmartin-tech Jul 10, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

post_buff_hook is not _post_buff_hook do we want to rename that to consider it private?

Copy link
Collaborator Author

@jmartin-tech jmartin-tech Jul 10, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It looks like description is added in the base class constructor if not already provided in the class implementation for most of the plugins, this results in an inconsistency since the cache build actually avoids calling the constructors in this PR.

I am attempting to address that issue as there should be a consistent value at least in the cache, one options I am thinking about would be, remove the constructor additions and migrate the enrichment to be part of plugin_info.

continue
value = getattr(plugin, v)
if (
v.startswith("_")
or inspect.ismethod(value)
or inspect.isfunction(value)
or v not in base_attributes
):
continue
plugin_metadata[v] = value
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

would like to see default params included here also. maybe something like this

Suggested change
plugin_metadata[v] = value
plugin_metadata[v] = value
plugin_metadata = plugin.DEFAULT_PARAMS | plugin_metadata

this makes some assumptions about default param names not conflicting with class attributes, but i hope testing elsewhere covers that

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This might be something to consider doing in the display logic, not sure it needs to be in the json cache file in a specific format.


except ValueError as e:
logging.exception(e)
except Exception as e:
logging.error(f"Plugin {plugin_name} not found.")
logging.exception(e)

from datetime import datetime, timezone

# adding last class modification time to cache allows for targeted update in future
current_mod = importlib.import_module(plugin.__module__)
mod_time = datetime.fromtimestamp(
os.path.getmtime(current_mod.__file__), tz=timezone.utc
)
plugin_metadata["mod_time"] = mod_time.strftime(TIME_FORMAT)
leondz marked this conversation as resolved.
Show resolved Hide resolved

return plugin_metadata


def plugin_info(plugin: Union[Callable, str]) -> dict:
return PluginCache.plugin_info(plugin)


def enumerate_plugins(
Expand All @@ -49,37 +240,17 @@ def enumerate_plugins(

base_mod = importlib.import_module(f"garak.{category}.base")

base_plugin_classnames = set(_extract_modules_klasses(base_mod))
base_plugin_classnames = set(PluginCache._extract_modules_klasses(base_mod))

plugin_class_names = []
plugin_class_names = set()

for module_filename in sorted(os.listdir(_config.transient.basedir / category)):
if not module_filename.endswith(".py"):
for k, v in PluginCache.instance()[category].items():
if skip_base_classes and k.split(".")[-1] in base_plugin_classnames:
continue
if module_filename.startswith("__"):
continue
if module_filename == "base.py" and skip_base_classes:
continue
module_name = module_filename.replace(".py", "")
mod = importlib.import_module(
f"garak.{category}.{module_name}"
) # import here will access all namespace level imports consider a cache to speed up processing
module_entries = set(_extract_modules_klasses(mod))
if skip_base_classes:
module_entries = module_entries.difference(base_plugin_classnames)
module_plugin_names = set()
for module_entry in module_entries:
obj = getattr(mod, module_entry)
for interface in base_plugin_classnames:
klass = getattr(base_mod, interface)
if issubclass(obj, klass):
module_plugin_names.add((module_entry, obj.active))

for module_plugin_name, active in sorted(module_plugin_names):
plugin_class_names.append(
(f"{category}.{module_name}.{module_plugin_name}", active)
)
return plugin_class_names
enum_entry = (k, v["active"])
plugin_class_names.add(enum_entry)

return sorted(plugin_class_names)


def load_plugin(path, break_on_fail=True, config_root=_config) -> object:
Expand Down
6 changes: 2 additions & 4 deletions garak/attempt.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,7 @@


class Attempt:
"""A class defining objects that represent everything that constitutes
a single attempt at evaluating an LLM.
"""A class defining objects that represent everything that constitutes a single attempt at evaluating an LLM.

:param status: The status of this attempt; ``ATTEMPT_NEW``, ``ATTEMPT_STARTED``, or ``ATTEMPT_COMPLETE``
:type status: int
Expand Down Expand Up @@ -169,8 +168,7 @@ def __getattribute__(self, name: str) -> Any:
return super().__getattribute__(name)

def __setattr__(self, name: str, value: Any) -> None:
"""override prompt and outputs access to take from history
NB. output elements need to be able to be None"""
"""override prompt and outputs access to take from history NB. output elements need to be able to be None"""

if name == "prompt":
if value is None:
Expand Down
30 changes: 10 additions & 20 deletions garak/command.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,30 +176,20 @@ def print_buffs():

# describe plugin
def plugin_info(plugin_name):
import inspect
from garak._plugins import plugin_info

from garak._plugins import load_plugin
leondz marked this conversation as resolved.
Show resolved Hide resolved

# load plugin
try:
plugin = load_plugin(plugin_name)
info = plugin_info(plugin_name)
if len(info) > 0:
print(f"Configured info on {plugin_name}:")
priority_fields = ["description"]
skip_fields = ["prompts", "triggers"]
# print the attribs it has
for v in priority_fields:
print(f"{v:>35}:", getattr(plugin, v))
for v in sorted(dir(plugin)):
if v in priority_fields or v in skip_fields:
continue
if v.startswith("_") or inspect.ismethod(getattr(plugin, v)):
for k in priority_fields:
if k in info:
print(f"{k:>35}:", info[k])
for k, v in info.items():
if k in priority_fields:
continue
print(f"{v:>35}:", getattr(plugin, v))

except ValueError as e:
print(e)
except Exception as e:
print(e)
print(f"{k:>35}:", v)
else:
print(
f"Plugin {plugin_name} not found. Try --list_probes, or --list_detectors."
)
Expand Down
12 changes: 6 additions & 6 deletions garak/detectors/misleading.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@


class MustRefuteClaimModel(HFDetector):
"""
This detector is looking for the model to refute a claim. A "hit" occurs when the generation
fails to include a refutation.
"""This detector is looking for the model to refute a claim.

A "hit" occurs when the generation fails to include a refutation.
"""

description = (
Expand Down Expand Up @@ -48,9 +48,9 @@ def detect(self, attempt: Attempt) -> List[float]:


class MustContradictNLI(HFDetector):
"""
This detector is looking for the model to contradict a claim. A "hit" occurs when the generation
fails to include a contradiction.
"""This detector is looking for the model to contradict a claim.

A "hit" occurs when the generation fails to include a contradiction.

Using ynie/roberta-large-snli_mnli_fever_anli_R1_R2_R3-nli to detect model refutations
"""
Expand Down
4 changes: 3 additions & 1 deletion garak/detectors/mitigation.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@


class MitigationBypass(StringDetector):
"""Use where the model is expected to put out a mitigation message
"""Detect if a string response does not contain a mitigation message.

Use where the model is expected to put out a mitigation message
instead of complying. Returns a success if no mitigation message is
detected. Runs using a set of keywords found in OpenAI ChatGPT
mitigation messages listed in the ShareGPT dataset used in training
Expand Down
3 changes: 1 addition & 2 deletions garak/detectors/packagehallucination.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,7 @@


class PythonPypi(Detector):
"""Check if the output tries to import a package not listed
in stdlib or a pypi archive listing"""
"""Check if the output tries to import a package not listed in stdlib or a pypi archive listing"""

pypi_packages = None
pypi_dataset_name = "strombergnlp/pypi-20230724"
Expand Down
1 change: 1 addition & 0 deletions garak/generators/huggingface.py
Original file line number Diff line number Diff line change
Expand Up @@ -457,6 +457,7 @@ def _pre_generate_hook(self):

class InferenceEndpoint(InferenceAPI):
"""Interface for Hugging Face private endpoints

Pass the model URL as the name, e.g. https://xxx.aws.endpoints.huggingface.cloud
"""

Expand Down
4 changes: 1 addition & 3 deletions garak/generators/litellm.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,9 +77,7 @@


class LiteLLMGenerator(Generator):
"""Generator wrapper using LiteLLM to allow access to different
providers using the OpenAI API format.
"""
"""Generator wrapper using LiteLLM to allow access to different providers using the OpenAI API format."""

ENV_VAR = "OPENAI_API_KEY"
DEFAULT_PARAMS = Generator.DEFAULT_PARAMS | {
Expand Down
Loading
Loading