Skip to content

Commit

Permalink
feat: add __repr__ to engines
Browse files Browse the repository at this point in the history
  • Loading branch information
zhudotexe committed Jul 1, 2024
1 parent 1236745 commit f38bd40
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 1 deletion.
1 change: 0 additions & 1 deletion kani/engines/anthropic/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,6 @@ def __init__(
self.hyperparams = hyperparams

# token counting - claude 3+ does not release tokenizer so we have to do heuristics and cache
self.token_cache = {}
if model.startswith("claude-2"):
# anthropic async client loads a json file using anyio for some reason; hook into the underlying loader
# noinspection PyProtectedMember
Expand Down
15 changes: 15 additions & 0 deletions kani/engines/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,18 @@ async def close(self):
"""Optional: Clean up any resources the engine might need."""
pass

# ==== internal ====
__ignored_repr_attrs__ = ("token_cache",)

def __repr__(self):
"""Default: generate a repr based on the instance's __dict__."""
attrs = ", ".join(
f"{name}={value!r}"
for name, value in self.__dict__.items()
if name not in self.__ignored_repr_attrs__ and not name.startswith("_")
)
return f"{type(self).__name__}({attrs})"


# ==== utils ====
class WrapperEngine(BaseEngine):
Expand Down Expand Up @@ -174,6 +186,9 @@ def function_token_reserve(self, functions: list[AIFunction]) -> int:
async def close(self):
return await self.engine.close()

def __repr__(self):
return f"{type(self).__name__}(engine={self.engine!r})"

# all other attributes are caught by this default passthrough handler
def __getattr__(self, item):
return getattr(self.engine, item)
6 changes: 6 additions & 0 deletions kani/engines/openai/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,3 +231,9 @@ def _function_token_reserve_impl(self, functions):

async def close(self):
await self.client.close()

def __repr__(self):
return (
f"{type(self).__name__}(model={self.model}, max_context_size={self.max_context_size},"
f" hyperparams={self.hyperparams})"
)

0 comments on commit f38bd40

Please sign in to comment.