Skip to content

Commit

Permalink
Merge branch 'main' into update/config_nested_dicts
Browse files Browse the repository at this point in the history
  • Loading branch information
leondz committed Jul 18, 2024
2 parents fc06af8 + 326d8ba commit 637ffa1
Show file tree
Hide file tree
Showing 21 changed files with 5,174 additions and 109 deletions.
12 changes: 1 addition & 11 deletions .github/workflows/test_windows.yml
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,6 @@ jobs:
with:
path: garak

- name: Checkout ecoji for modified windows install
uses: actions/checkout@v3
with:
repository: mecforlove/ecoji-py
path: ecoji-py

- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v4
with:
Expand All @@ -27,11 +21,7 @@ jobs:
- name: Install dependencies
run: |
python -m pip install --upgrade pip
cd ecoji-py
echo "mitigate" > README.md
pip install setuptools
python setup.py install
cd ../garak
cd garak
pip install -r requirements.txt
- name: Test with pytest
Expand Down
243 changes: 207 additions & 36 deletions garak/_plugins.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,24 +5,215 @@

import importlib
import inspect
import json
import logging
import shutil
import os
from typing import List
from typing import List, Callable, Union
from pathlib import Path

from garak import _config
from garak.exception import GarakException

PLUGIN_TYPES = ("probes", "detectors", "generators", "harnesses", "buffs")
PLUGIN_CLASSES = ("Probe", "Detector", "Generator", "Harness", "Buff")
TIME_FORMAT = "%Y-%m-%d %H:%M:%S %z"


@staticmethod
def _extract_modules_klasses(base_klass):
return [ # Extract only classes with same source package name
name
for name, klass in inspect.getmembers(base_klass, inspect.isclass)
if klass.__module__.startswith(base_klass.__name__)
]
class PluginEncoder(json.JSONEncoder):
def default(self, obj):
if isinstance(obj, set):
return list(obj).sort() # allow set as list, assumes values can be sorted
if isinstance(obj, Path):
# relative path for now, may be better to suppress `Path` objects
return str(obj).replace(str(_config.transient.basedir), "")
try:
return json.JSONEncoder.default(self, obj)
except TypeError as e:
logging.debug("Attempt to serialize JSON skipped: %s", e)
return None # skip items that cannot be serialized at this time


class PluginCache:
_plugin_cache_file = _config.transient.basedir / "resources" / "plugin_cache.json"
_user_plugin_cache_file = _plugin_cache_file
_plugin_cache_dict = None

def __init__(self) -> None:
if PluginCache._plugin_cache_dict is None:
PluginCache._plugin_cache_dict = self._load_plugin_cache()

@staticmethod
def _extract_modules_klasses(base_klass):
return [ # Extract only classes with same source package name
name
for name, klass in inspect.getmembers(base_klass, inspect.isclass)
if klass.__module__.startswith(base_klass.__name__)
]

def _load_plugin_cache(self):
if not os.path.exists(self._plugin_cache_file):
self._build_plugin_cache()
if not os.path.exists(self._user_plugin_cache_file):
shutil.copy2(self._plugin_cache_file, self._user_plugin_cache_file)
with open(self._user_plugin_cache_file, "r", encoding="utf-8") as cache_file:
local_cache = json.load(cache_file)
return local_cache

def _build_plugin_cache(self):
"""build a plugin cache file to improve access times
This method writes only to the user's cache (currently the same as the system cache)
TODO: Enhance location of user cache to enable support for in development plugins.
"""
local_cache = {}

for plugin_type in PLUGIN_TYPES:
plugin_dict = {}
for plugin in self._enumerate_plugin_klasses(plugin_type):
plugin_name = ".".join([plugin.__module__, plugin.__name__]).replace(
"garak.", ""
)
plugin_dict[plugin_name] = PluginCache.plugin_info(plugin)

sorted_keys = sorted(list(plugin_dict.keys()))
local_cache[plugin_type] = {i: plugin_dict[i] for i in sorted_keys}

with open(self._user_plugin_cache_file, "w", encoding="utf-8") as cache_file:
json.dump(local_cache, cache_file, cls=PluginEncoder, indent=2)

def _enumerate_plugin_klasses(self, category: str) -> List[Callable]:
"""obtain all"""
if category not in PLUGIN_TYPES:
raise ValueError("Not a recognised plugin type:", category)

base_mod = importlib.import_module(f"garak.{category}.base")

base_plugin_classnames = set(self._extract_modules_klasses(base_mod))

module_plugin_names = set()

for module_filename in sorted(os.listdir(_config.transient.basedir / category)):
if not module_filename.endswith(".py"):
continue
if module_filename.startswith("__"):
continue
module_name = module_filename.replace(".py", "")
mod = importlib.import_module(f"garak.{category}.{module_name}")
module_entries = set(self._extract_modules_klasses(mod))

for module_entry in module_entries:
obj = getattr(mod, module_entry)
for interface in base_plugin_classnames:
klass = getattr(base_mod, interface)
if issubclass(obj, klass):
module_plugin_names.add(obj)

return module_plugin_names

def instance() -> dict:
return PluginCache()._plugin_cache_dict

def plugin_info(plugin: Union[Callable, str]) -> dict:
"""retrieves the standard attributes for the plugin type"""
if isinstance(plugin, str):
plugin_name = plugin
category = plugin_name.split(".")[0]

if category not in PLUGIN_TYPES:
raise ValueError(f"Not a recognised plugin type: {category}")

plugin_metadata = PluginCache.instance()[category].get(plugin_name, {})
if len(plugin_metadata) > 0:
return plugin_metadata
else:
# the requested plugin is not cached import the class for eval
parts = plugin.split(".")
match len(parts):
case 3:
try:
module = ".".join(parts[:-1])
klass = parts[-1]
imported_module = importlib.import_module(f"garak.{module}")
plugin = getattr(imported_module, klass)
except (AttributeError, ModuleNotFoundError) as e:
if isinstance(e, AttributeError):
msg = f"Not a recognised plugin from {module}: {klass}"
else:
msg = f"Not a recognised plugin module: {plugin}"
raise ValueError(msg)
case _:
raise ValueError(f"Not a recognised plugin class: {plugin}")
else:
plugin_name = ".".join([plugin.__module__, plugin.__name__]).replace(
"garak.", ""
)
category = plugin_name.split(".")[0]

try:
base_attributes = []
base_mod = importlib.import_module(f"garak.{category}.base")
base_plugin_classes = set(PluginCache._extract_modules_klasses(base_mod))
if plugin.__module__ in base_mod.__name__:
# this is a base class enumerate all
base_attributes = dir(plugin)
else:
for klass in base_plugin_classes:
# filter to the base class actually implemented
if issubclass(plugin, getattr(base_mod, klass)):
base_attributes += PluginCache.plugin_info(
getattr(base_mod, klass)
).keys()

plugin_metadata = {}
priority_fields = ["description"]
skip_fields = [
"prompts",
"triggers",
"post_buff_hook",
]

# description as doc string will be overwritten if provided by the class
desc = plugin.__doc__
if desc is not None:
plugin_metadata["description"] = desc.split("\n")[0]

for v in priority_fields:
if hasattr(plugin, v):
plugin_metadata[v] = getattr(plugin, v)
for v in sorted(dir(plugin)):
if v in priority_fields or v in skip_fields:
continue
value = getattr(plugin, v)
if (
v.startswith("_")
or inspect.ismethod(value)
or inspect.isfunction(value)
or v not in base_attributes
):
continue
plugin_metadata[v] = value

except ValueError as e:
logging.exception(e)
except Exception as e:
logging.error(f"Plugin {plugin_name} not found.")
logging.exception(e)

from datetime import datetime, timezone

# adding last class modification time to cache allows for targeted update in future
current_mod = importlib.import_module(plugin.__module__)
mod_time = datetime.fromtimestamp(
os.path.getmtime(current_mod.__file__), tz=timezone.utc
)
plugin_metadata["mod_time"] = mod_time.strftime(TIME_FORMAT)

return plugin_metadata


def plugin_info(plugin: Union[Callable, str]) -> dict:
return PluginCache.plugin_info(plugin)


def enumerate_plugins(
Expand All @@ -49,37 +240,17 @@ def enumerate_plugins(

base_mod = importlib.import_module(f"garak.{category}.base")

base_plugin_classnames = set(_extract_modules_klasses(base_mod))
base_plugin_classnames = set(PluginCache._extract_modules_klasses(base_mod))

plugin_class_names = []
plugin_class_names = set()

for module_filename in sorted(os.listdir(_config.transient.basedir / category)):
if not module_filename.endswith(".py"):
for k, v in PluginCache.instance()[category].items():
if skip_base_classes and k.split(".")[-1] in base_plugin_classnames:
continue
if module_filename.startswith("__"):
continue
if module_filename == "base.py" and skip_base_classes:
continue
module_name = module_filename.replace(".py", "")
mod = importlib.import_module(
f"garak.{category}.{module_name}"
) # import here will access all namespace level imports consider a cache to speed up processing
module_entries = set(_extract_modules_klasses(mod))
if skip_base_classes:
module_entries = module_entries.difference(base_plugin_classnames)
module_plugin_names = set()
for module_entry in module_entries:
obj = getattr(mod, module_entry)
for interface in base_plugin_classnames:
klass = getattr(base_mod, interface)
if issubclass(obj, klass):
module_plugin_names.add((module_entry, obj.active))

for module_plugin_name, active in sorted(module_plugin_names):
plugin_class_names.append(
(f"{category}.{module_name}.{module_plugin_name}", active)
)
return plugin_class_names
enum_entry = (k, v["active"])
plugin_class_names.add(enum_entry)

return sorted(plugin_class_names)


def load_plugin(path, break_on_fail=True, config_root=_config) -> object:
Expand Down
13 changes: 7 additions & 6 deletions garak/attempt.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
"""Defines the Attempt class, which encapsulates a prompt with metadata and results"""

from collections.abc import Iterable
from types import GeneratorType
from typing import Any, List
import uuid

Expand All @@ -13,8 +15,7 @@


class Attempt:
"""A class defining objects that represent everything that constitutes
a single attempt at evaluating an LLM.
"""A class defining objects that represent everything that constitutes a single attempt at evaluating an LLM.
:param status: The status of this attempt; ``ATTEMPT_NEW``, ``ATTEMPT_STARTED``, or ``ATTEMPT_COMPLETE``
:type status: int
Expand Down Expand Up @@ -169,8 +170,7 @@ def __getattribute__(self, name: str) -> Any:
return super().__getattribute__(name)

def __setattr__(self, name: str, value: Any) -> None:
"""override prompt and outputs access to take from history
NB. output elements need to be able to be None"""
"""override prompt and outputs access to take from history NB. output elements need to be able to be None"""

if name == "prompt":
if value is None:
Expand All @@ -179,8 +179,9 @@ def __setattr__(self, name: str, value: Any) -> None:
self._add_first_turn("user", value)

elif name == "outputs":
if not isinstance(value, list):
raise TypeError("Value for attempt.outputs must be a list")
if not (isinstance(value, list) or isinstance(value, GeneratorType)):
raise TypeError("Value for attempt.outputs must be a list or generator")
value = list(value)
if len(self.messages) == 0:
raise TypeError("A prompt must be set before outputs are given")
# do we have only the initial prompt? in which case, let's flesh out messages a bit
Expand Down
30 changes: 10 additions & 20 deletions garak/command.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,30 +176,20 @@ def print_buffs():

# describe plugin
def plugin_info(plugin_name):
import inspect
from garak._plugins import plugin_info

from garak._plugins import load_plugin

# load plugin
try:
plugin = load_plugin(plugin_name)
info = plugin_info(plugin_name)
if len(info) > 0:
print(f"Configured info on {plugin_name}:")
priority_fields = ["description"]
skip_fields = ["prompts", "triggers"]
# print the attribs it has
for v in priority_fields:
print(f"{v:>35}:", getattr(plugin, v))
for v in sorted(dir(plugin)):
if v in priority_fields or v in skip_fields:
continue
if v.startswith("_") or inspect.ismethod(getattr(plugin, v)):
for k in priority_fields:
if k in info:
print(f"{k:>35}:", info[k])
for k, v in info.items():
if k in priority_fields:
continue
print(f"{v:>35}:", getattr(plugin, v))

except ValueError as e:
print(e)
except Exception as e:
print(e)
print(f"{k:>35}:", v)
else:
print(
f"Plugin {plugin_name} not found. Try --list_probes, or --list_detectors."
)
Expand Down
Loading

0 comments on commit 637ffa1

Please sign in to comment.