Skip to content

Commit

Permalink
Fixed all the excess new lines
Browse files Browse the repository at this point in the history
  • Loading branch information
tanaygodse committed Sep 30, 2024
1 parent b7368eb commit 6e8c494
Showing 1 changed file with 7 additions and 16 deletions.
23 changes: 7 additions & 16 deletions ai_engine_sdk/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,6 @@ class Session:
_messages (List[ApiBaseMessage]): A list to store messages associated with the session.
_message_ids (set[str]): A set to store unique message IDs to prevent duplication.
"""

def __init__(self, api_base_url: str, api_key: str, session_id: str, function_group: str):
"""
Initializes a new session with the given parameters.
Expand Down Expand Up @@ -274,8 +273,7 @@ async def get_messages(self) -> List[ApiBaseMessage]:
agent_json: dict = message['agent_json']
agent_json_type: str = agent_json['type'].upper()
if is_task_selection_message(message_type=agent_json_type):
indexed_task_options: dict = get_indexed_task_options_from_raw_api_response(
raw_api_response=message)
indexed_task_options: dict = get_indexed_task_options_from_raw_api_response(raw_api_response=message)
newMessages.append(
TaskSelectionMessage.model_validate({
'type': agent_json_type,
Expand Down Expand Up @@ -363,17 +361,14 @@ async def execute_function(self, function_ids: list[str], objective: str, contex
})
)


class AiEngine:
def __init__(self, api_key: str, options: Optional[dict] = None):
self._api_base_url = options.get(
'api_base_url') if options and 'api_base_url' in options else default_api_base_url
self._api_base_url = options.get('api_base_url') if options and 'api_base_url' in options else default_api_base_url
self._api_key = api_key

####
# Function groups
####

async def get_function_groups(self) -> List[FunctionGroup]:
logger.debug("get_function_groups")
publicGroups, privateGroups = await asyncio.gather(
Expand Down Expand Up @@ -466,7 +461,6 @@ async def get_function_group_by_function(self, function_id: str):
###
# Functions
###

async def get_functions_by_function_group(self, function_group_id: str) -> list[FunctionGroupFunctions]:
raw_response: dict = await make_api_request(
api_base_url=self._api_base_url,
Expand All @@ -478,14 +472,14 @@ async def get_functions_by_function_group(self, function_group_id: str) -> list[
if "functions" in raw_response:
list(
map(
lambda function_name: FunctionGroupFunctions.model_validate(
{"name": function_name}),
lambda function_name: FunctionGroupFunctions.model_validate({"name": function_name}),
raw_response["functions"]
)
)

return result


async def get_functions(self) -> list[Function]:
raw_response: dict = await make_api_request(
api_base_url=self._api_base_url,
Expand All @@ -502,10 +496,8 @@ async def get_functions(self) -> list[Function]:
####
# Model
####

async def get_models(self) -> List[Model]:
pending_credits = [self.get_model_credits(
model_id) for model_id in DefaultModelIds]
pending_credits = [self.get_model_credits(model_id) for model_id in DefaultModelIds]

models = [Model(
id=model_id,
Expand Down Expand Up @@ -552,8 +544,7 @@ async def create_session(self, function_group: str, opts: Optional[dict] = None)
email=opts.get('email') if opts else "",
functionGroup=function_group,
preferencesEnabled=False,
requestModel=opts.get(
'model') if opts and 'model' in opts else DefaultModelId
requestModel=opts.get('model') if opts and 'model' in opts else DefaultModelId
)
response = await make_api_request(
api_base_url=self._api_base_url,
Expand Down Expand Up @@ -586,4 +577,4 @@ async def share_function_group(
payload=payload
)
logger.debug(f"FG successfully shared: {function_group_id} with {target_user_email}")
return raw_response
return raw_response

0 comments on commit 6e8c494

Please sign in to comment.