Skip to content

Commit

Permalink
Apply ApiClient permissions to everything (#1057)
Browse files Browse the repository at this point in the history
* Apply ApiClient to everything for proper perms

* Update Agents

* Updates

* lint
  • Loading branch information
Josh-XT authored Oct 19, 2023
1 parent 5ab2cd7 commit 8ad3770
Show file tree
Hide file tree
Showing 23 changed files with 303 additions and 156 deletions.
5 changes: 5 additions & 0 deletions agixt/ApiClient.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,3 +63,8 @@ def verify_api_key(authorization: str = Header(None)):
raise HTTPException(status_code=401, detail="Invalid API Key")
else:
return "USER"


def get_api_client(authorization: str = Header(None)):
scheme, _, api_key = authorization.partition(" ")
return AGiXTSDK(base_uri="http://localhost:7437", api_key=api_key)
13 changes: 7 additions & 6 deletions agixt/Chains.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
import logging
from ApiClient import ApiClient, Chain, Prompts, log_interaction
from ApiClient import Chain, Prompts, log_interaction
from Extensions import Extensions


class Chains:
def __init__(self, user="USER"):
def __init__(self, user="USER", ApiClient=None):
self.user = user
self.chain = Chain(user=user)
self.ApiClient = ApiClient

async def run_chain_step(
self,
Expand Down Expand Up @@ -44,14 +45,14 @@ async def run_chain_step(
if "conversation" in args:
args["conversation_name"] = args["conversation"]
if prompt_type == "Command":
return ApiClient.execute_command(
return self.ApiClient.execute_command(
agent_name=agent_name,
command_name=step["prompt"]["command_name"],
command_args=args,
conversation_name=args["conversation_name"],
)
elif prompt_type == "Prompt":
result = ApiClient.prompt_agent(
result = self.ApiClient.prompt_agent(
agent_name=agent_name,
prompt_name=prompt_name,
prompt_args={
Expand All @@ -62,7 +63,7 @@ async def run_chain_step(
},
)
elif prompt_type == "Chain":
result = ApiClient.run_chain(
result = self.ApiClient.run_chain(
chain_name=args["chain"],
user_input=args["input"],
agent_name=agent_name,
Expand Down Expand Up @@ -94,7 +95,7 @@ async def run_chain(
from_step=1,
chain_args={},
):
chain_data = ApiClient.get_chain(chain_name=chain_name)
chain_data = self.ApiClient.get_chain(chain_name=chain_name)
if chain_data == {}:
return f"Chain `{chain_name}` not found."
log_interaction(
Expand Down
12 changes: 4 additions & 8 deletions agixt/Extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,13 @@ def __init__(
agent_name="",
agent_config=None,
conversation_name="",
load_commands_flag: bool = True,
ApiClient=None,
):
self.agent_config = agent_config
self.agent_name = agent_name if agent_name else "gpt4free"
self.conversation_name = conversation_name
if load_commands_flag:
self.commands = self.load_commands()
else:
self.commands = []
self.ApiClient = ApiClient
self.commands = self.load_commands()
if agent_config != None:
if "commands" not in self.agent_config:
self.agent_config["commands"] = {}
Expand Down Expand Up @@ -146,14 +144,12 @@ def get_commands_list(self):
return commands_list

async def execute_command(self, command_name: str, command_args: dict = None):
print("Running Command (Extensions.py)")
print(command_name)
print(command_args)
injection_variables = {
"agent_name": self.agent_name,
"command_name": command_name,
"conversation_name": self.conversation_name,
"enabled_commands": self.get_enabled_commands(),
"ApiClient": self.ApiClient,
**self.agent_config["settings"],
}
command_function, module, params = self.find_command(command_name=command_name)
Expand Down
33 changes: 24 additions & 9 deletions agixt/Interactions.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
from Websearch import Websearch
from Extensions import Extensions
from ApiClient import (
ApiClient,
Agent,
Prompts,
Chain,
Expand All @@ -27,13 +26,20 @@ def get_tokens(text: str) -> int:


class Interactions:
def __init__(self, agent_name: str = "", collection_number: int = 0, user="USER"):
def __init__(
self,
agent_name: str = "",
collection_number: int = 0,
user="USER",
ApiClient=None,
):
if agent_name != "":
self.agent_name = agent_name
self.agent = Agent(self.agent_name, user=user)
self.agent = Agent(self.agent_name, user=user, ApiClient=ApiClient)
self.agent_commands = self.agent.get_commands_string()
self.websearch = Websearch(
agent_name=self.agent_name,
ApiClient=ApiClient,
searxng_instance_url=self.agent.AGENT_CONFIG["settings"][
"SEARXNG_INSTANCE_URL"
]
Expand All @@ -48,13 +54,15 @@ def __init__(self, agent_name: str = "", collection_number: int = 0, user="USER"
agent_name=self.agent_name,
agent_config=self.agent.AGENT_CONFIG,
collection_number=int(collection_number),
ApiClient=ApiClient,
)
self.stop_running_event = None
self.browsed_links = []
self.failures = 0
self.user = user
self.chain = Chain(user=user)
self.cp = Prompts(user=user)
self.ApiClient = ApiClient

def custom_format(self, string, **kwargs):
if isinstance(string, list):
Expand Down Expand Up @@ -118,6 +126,7 @@ async def format_prompt(
agent_name=self.agent_name,
agent_config=self.agent.AGENT_CONFIG,
collection_number=2,
ApiClient=self.ApiClient,
).get_memories(
user_input=user_input,
limit=3,
Expand All @@ -127,6 +136,7 @@ async def format_prompt(
agent_name=self.agent_name,
agent_config=self.agent.AGENT_CONFIG,
collection_number=3,
ApiClient=self.ApiClient,
).get_memories(
user_input=user_input,
limit=3,
Expand All @@ -143,6 +153,7 @@ async def format_prompt(
agent_name=self.agent_name,
agent_config=self.agent.AGENT_CONFIG,
collection_number=1,
ApiClient=self.ApiClient,
).get_memories(
user_input=user_input,
limit=top_results,
Expand All @@ -156,6 +167,7 @@ async def format_prompt(
collection_number=int(
kwargs["inject_memories_from_collection_number"]
),
ApiClient=self.ApiClient,
).get_memories(
user_input=user_input,
limit=top_results,
Expand Down Expand Up @@ -460,7 +472,7 @@ async def run(
time.sleep(10)
if context_results > 0:
context_results = context_results - 1
return ApiClient.prompt_agent(
return self.ApiClient.prompt_agent(
agent_name=self.agent_name,
prompt_name=prompt,
prompt_args={
Expand Down Expand Up @@ -500,7 +512,7 @@ async def run(
if shots > 1:
responses = [self.response]
for shot in range(shots - 1):
shot_response = ApiClient.prompt_agent(
shot_response = self.ApiClient.prompt_agent(
agent_name=self.agent_name,
prompt_name=prompt,
prompt_args={
Expand All @@ -525,16 +537,19 @@ async def run(
return self.response

def create_command_suggestion_chain(self, agent_name, command_name, command_args):
chains = ApiClient.get_chains()
chains = self.ApiClient.get_chains()
chain_name = f"{agent_name} Command Suggestions"
if chain_name in chains:
step = (
int(ApiClient.get_chain(chain_name=chain_name)["steps"][-1]["step"]) + 1
int(
self.ApiClient.get_chain(chain_name=chain_name)["steps"][-1]["step"]
)
+ 1
)
else:
ApiClient.add_chain(chain_name=chain_name)
self.ApiClient.add_chain(chain_name=chain_name)
step = 1
ApiClient.add_step(
self.ApiClient.add_step(
chain_name=chain_name,
agent_name=agent_name,
step_number=step,
Expand Down
7 changes: 5 additions & 2 deletions agixt/Memories.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
from datetime import datetime
from collections import Counter
from typing import List
from ApiClient import ApiClient


if sys.platform == "win32":
Expand Down Expand Up @@ -100,7 +99,11 @@ def query_results_to_records(results: "QueryResult"):

class Memories:
def __init__(
self, agent_name: str = "AGiXT", agent_config=None, collection_number: int = 0
self,
agent_name: str = "AGiXT",
agent_config=None,
collection_number: int = 0,
ApiClient=None,
):
self.agent_name = agent_name
self.collection_name = camel_to_snake(agent_name)
Expand Down
3 changes: 2 additions & 1 deletion agixt/Providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,10 +67,11 @@ def get_providers_with_settings():


class Providers:
def __init__(self, name, **kwargs):
def __init__(self, name, ApiClient, **kwargs):
if name in DISABLED_PROVIDERS:
raise AttributeError(f"module {__name__} has no attribute {name}")
try:
kwargs["ApiClient"] = ApiClient
module = importlib.import_module(f"providers.{name}")
provider_class = getattr(module, f"{name.capitalize()}Provider")
self.instance = provider_class(**kwargs)
Expand Down
17 changes: 9 additions & 8 deletions agixt/Websearch.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,19 +9,20 @@
from playwright.async_api import async_playwright
from bs4 import BeautifulSoup
from typing import List
from ApiClient import ApiClient


class Websearch:
def __init__(
self,
agent_name: str = "AGiXT",
searxng_instance_url: str = "",
ApiClient=None,
**kwargs,
):
self.agent_name = agent_name
self.searx_instance_url = searxng_instance_url
self.agent_config = ApiClient.get_agentconfig(agent_name=self.agent_name)
self.ApiClient = ApiClient
self.agent_config = self.ApiClient.get_agentconfig(agent_name=self.agent_name)
self.agent_settings = self.agent_config["settings"]
self.requirements = ["agixtsdk"]
self.failures = []
Expand Down Expand Up @@ -52,7 +53,7 @@ async def get_web_content(self, url):
soup = BeautifulSoup(content, "html.parser")
text_content = soup.get_text()
text_content = " ".join(text_content.split())
ApiClient.learn_url(
self.ApiClient.learn_url(
agent_name=self.agent_name, url=url, collection_number=1
)
self.browsed_links.append(url)
Expand Down Expand Up @@ -115,7 +116,7 @@ async def resursive_browsing(self, user_input, links):
if len(link_list) > 5:
link_list = link_list[:3]
try:
pick_a_link = ApiClient.prompt_agent(
pick_a_link = self.ApiClient.prompt_agent(
agent_name=self.agent_name,
prompt_name="Pick-a-Link",
prompt_args={
Expand Down Expand Up @@ -179,12 +180,12 @@ async def search(self, query: str) -> List[str]:
except: # Select default remote server that typically works if unable to get list.
self.searx_instance_url = "https://search.us.projectsegfau.lt"
self.agent_settings["SEARXNG_INSTANCE_URL"] = self.searx_instance_url
ApiClient.update_agent_settings(
self.ApiClient.update_agent_settings(
agent_name=self.agent_name, settings=self.agent_settings
)
server = self.searx_instance_url.rstrip("/")
self.agent_settings["SEARXNG_INSTANCE_URL"] = server
ApiClient.update_agent_settings(
self.ApiClient.update_agent_settings(
agent_name=self.agent_name, settings=self.agent_settings
)
endpoint = f"{server}/search"
Expand Down Expand Up @@ -213,7 +214,7 @@ async def search(self, query: str) -> List[str]:
if len(self.failures) > 5:
logging.info("Failed 5 times. Trying DDG...")
self.agent_settings["SEARXNG_INSTANCE_URL"] = ""
ApiClient.update_agent_settings(
self.ApiClient.update_agent_settings(
agent_name=self.agent_name, settings=self.agent_settings
)
return await self.ddg_search(query=query)
Expand Down Expand Up @@ -264,7 +265,7 @@ async def websearch_agent(
except:
websearch_timeout = 0
if websearch_depth > 0:
search_string = ApiClient.prompt_agent(
search_string = self.ApiClient.prompt_agent(
agent_name=self.agent_name,
prompt_name="WebSearch",
prompt_args={
Expand Down
6 changes: 4 additions & 2 deletions agixt/db/Agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ def import_agent_config(agent_name, user="USER"):


class Agent:
def __init__(self, agent_name=None, user="USER"):
def __init__(self, agent_name=None, user="USER", ApiClient=None):
self.agent_name = agent_name if agent_name is not None else "AGiXT"
self.AGENT_CONFIG = self.get_agent_config()
self.load_config_keys()
Expand All @@ -212,7 +212,9 @@ def __init__(self, agent_name=None, user="USER"):
if setting not in self.PROVIDER_SETTINGS:
self.PROVIDER_SETTINGS[setting] = DEFAULT_SETTINGS[setting]
self.AI_PROVIDER = self.AGENT_CONFIG["settings"]["provider"]
self.PROVIDER = Providers(self.AI_PROVIDER, **self.PROVIDER_SETTINGS)
self.PROVIDER = Providers(
name=self.AI_PROVIDER, ApiClient=ApiClient, **self.PROVIDER_SETTINGS
)
self.available_commands = Extensions(
agent_name=self.agent_name, agent_config=self.AGENT_CONFIG
).get_available_commands()
Expand Down
Loading

0 comments on commit 8ad3770

Please sign in to comment.