-
Notifications
You must be signed in to change notification settings - Fork 14
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
integrations compatible with agent #57
Changes from 2 commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,170 +1,65 @@ | ||
import os | ||
import logging | ||
import pprint | ||
|
||
import numpy as np | ||
from addresses import ADDRESSES | ||
from dotenv import find_dotenv, load_dotenv | ||
from lp_tools import get_tick_range | ||
from mint_position import close_position, get_all_user_positions, get_mint_params | ||
from prefect import get_run_logger | ||
|
||
from giza.agents.action import action | ||
from giza.agents import AgentResult, GizaAgent | ||
from giza.agents.task import task | ||
from giza.agents import GizaAgent | ||
|
||
load_dotenv(find_dotenv()) | ||
|
||
# Here we load a custom sepolia rpc url from the environment | ||
sepolia_rpc_url = os.environ.get("SEPOLIA_RPC_URL") | ||
|
||
MODEL_ID = ... # Update with your model ID | ||
VERSION_ID = ... # Update with your version ID | ||
|
||
|
||
@task | ||
def process_data(realized_vol, dec_price_change): | ||
pct_change_sq = (100 * dec_price_change) ** 2 | ||
X = np.array([[realized_vol, pct_change_sq]]) | ||
return X | ||
|
||
|
||
# Get image | ||
@task | ||
def get_data(): | ||
# TODO: implement fetching onchain or from some other source | ||
realized_vol = 4.20 | ||
dec_price_change = 0.1 | ||
return realized_vol, dec_price_change | ||
|
||
|
||
@task | ||
def create_agent( | ||
model_id: int, version_id: int, chain: str, contracts: dict, account: str | ||
): | ||
""" | ||
Create a Giza agent for the volatility prediction model | ||
""" | ||
def transmission(): | ||
logger = logging.getLogger(__name__) | ||
id = ... | ||
version = ... | ||
account = ... | ||
realized_vol, dec_price_change = get_data() | ||
input_data = process_data(realized_vol, dec_price_change) | ||
|
||
agent = GizaAgent( | ||
contracts=contracts, | ||
id=model_id, | ||
version_id=version_id, | ||
chain=chain, | ||
integrations=["UniswapV3"], | ||
id=id, | ||
chain="ethereum:sepolia:https://sepolia.infura.io/v3/765888cfa824440c8c0feb5b492abedd", | ||
version_id=version, | ||
account=account, | ||
) | ||
return agent | ||
|
||
|
||
@task | ||
def predict(agent: GizaAgent, X: np.ndarray): | ||
""" | ||
Predict the digit in an image. | ||
|
||
Args: | ||
image (np.ndarray): Image to predict. | ||
|
||
Returns: | ||
int: Predicted digit. | ||
""" | ||
prediction = agent.predict(input_feed={"val": X}, verifiable=True, job_size="XL") | ||
return prediction | ||
|
||
|
||
@task | ||
def get_pred_val(prediction: AgentResult): | ||
""" | ||
Get the value from the prediction. | ||
|
||
Args: | ||
prediction (dict): Prediction from the model. | ||
|
||
Returns: | ||
int: Predicted value. | ||
""" | ||
# This will block the executon until the prediction has generated the proof and the proof has been verified | ||
return prediction.value[0][0] | ||
|
||
|
||
# Create Action | ||
@action | ||
def transmission( | ||
pred_model_id, | ||
pred_version_id, | ||
account="dev", | ||
chain=f"ethereum:sepolia:{sepolia_rpc_url}", | ||
): | ||
logger = get_run_logger() | ||
|
||
nft_manager_address = ADDRESSES["NonfungiblePositionManager"][11155111] | ||
tokenA_address = ADDRESSES["UNI"][11155111] | ||
tokenB_address = ADDRESSES["WETH"][11155111] | ||
pool_address = "0x287B0e934ed0439E2a7b1d5F0FC25eA2c24b64f7" | ||
user_address = "0xCBB090699E0664f0F6A4EFbC616f402233718152" | ||
|
||
pool_fee = 3000 | ||
tokenA_amount = 1000 | ||
tokenB_amount = 1000 | ||
|
||
logger.info("Fetching input data") | ||
realized_vol, dec_price_change = get_data() | ||
|
||
logger.info(f"Input data: {realized_vol}, {dec_price_change}") | ||
X = process_data(realized_vol, dec_price_change) | ||
|
||
nft_manager_abi_path = "nft_manager_abi.json" | ||
contracts = { | ||
"nft_manager": [nft_manager_address, nft_manager_abi_path], | ||
"tokenA": [tokenA_address], | ||
"tokenB": tokenB_address, | ||
"pool": pool_address, | ||
} | ||
agent = create_agent( | ||
model_id=pred_model_id, | ||
version_id=pred_version_id, | ||
chain=chain, | ||
contracts=contracts, | ||
account=account, | ||
result = agent.predict( | ||
input_feed={"val": input_data}, verifiable=True, dry_run=True | ||
) | ||
result = predict(agent, X) | ||
predicted_value = get_pred_val(result) | ||
logger.info(f"Result: {result}") | ||
|
||
logger.info(f"Result: {result}") | ||
with agent.execute() as contracts: | ||
logger.info("Executing contract") | ||
# TODO: fix below | ||
positions = get_all_user_positions(contracts.nft_manager, user_address) | ||
logger.info(f"Found the following positions: {positions}") | ||
# step 4: close all positions | ||
logger.info("Closing all open positions...") | ||
for nft_id in positions: | ||
close_position(user_address, contracts.nft_manager, nft_id) | ||
# step 4: calculate mint params | ||
logger.info("Calculating mint params...") | ||
_, curr_tick, _, _, _, _, _ = contracts.pool.slot0() | ||
tokenA_decimals = contracts.tokenA.decimals() | ||
tokenB_decimals = contracts.tokenB.decimals() | ||
# TODO: confirm input should be result and not result * 100 | ||
lower_tick, upper_tick = get_tick_range( | ||
curr_tick, predicted_value, tokenA_decimals, tokenB_decimals, pool_fee | ||
UNI_address = "0x1f9840a85d5aF5bf1D1762F925BDADdC4201F984" | ||
WETH_address = "0xfFf9976782d46CC05630D1f6eBAb18b2324d6B14" | ||
uni = contracts.UniswapV3 | ||
volatility_prediction = result.value[0] | ||
pool = uni.get_pool(UNI_address, WETH_address, fee=500) | ||
curr_price = pool.get_pool_price() | ||
lower_price = curr_price * (1 - volatility_prediction) | ||
upper_price = curr_price * (1 + volatility_prediction) | ||
amount0 = 100 | ||
amount1 = 100 | ||
agent_result = uni.mint_position( | ||
pool, lower_price, upper_price, amount0, amount1 | ||
) | ||
mint_params = get_mint_params( | ||
tokenA_address, | ||
tokenB_address, | ||
user_address, | ||
tokenA_amount, | ||
tokenB_amount, | ||
pool_fee, | ||
lower_tick, | ||
upper_tick, | ||
logger.info( | ||
f"Current price: {curr_price}, new bounds: {lower_price}, {upper_price}" | ||
) | ||
# step 5: mint new position | ||
logger.info("Minting new position...") | ||
contract_result = contracts.nft_manager.mint(mint_params) | ||
logger.info("SUCCESSFULLY MINTED A POSITION") | ||
logger.info("Contract executed") | ||
logger.info(f"Minted position: {agent_result}") | ||
|
||
logger.info(f"Contract result: {contract_result}") | ||
pprint.pprint(contract_result.__dict__) | ||
logger.info(f"Contract result: {agent_result}") | ||
logger.info("Finished") | ||
|
||
|
||
transmission(MODEL_ID, VERSION_ID) | ||
transmission() |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -7,6 +7,7 @@ | |
from typing import Any, Callable, Dict, List, Optional, Self, Tuple, Union | ||
|
||
from ape import Contract, accounts, networks | ||
from ape.api import AccountAPI | ||
from ape.contracts import ContractInstance | ||
from ape.exceptions import NetworkError | ||
from ape_accounts.accounts import InvalidPasswordError | ||
|
@@ -18,6 +19,7 @@ | |
from giza.cli.utils.enums import JobKind, JobStatus | ||
from requests import HTTPError | ||
|
||
from giza.agents.integration import IntegrationFactory | ||
from giza.agents.model import GizaModel | ||
from giza.agents.utils import read_json | ||
|
||
|
@@ -34,7 +36,8 @@ def __init__( | |
self, | ||
id: int, | ||
version_id: int, | ||
contracts: Dict[str, Union[str, List[str]]], | ||
contracts: Optional[Dict[str, Union[str, List[str]]]] = None, | ||
integrations: Optional[List[str]] = None, | ||
chain: Optional[str] = None, | ||
account: Optional[str] = None, | ||
**kwargs: Any, | ||
|
@@ -44,6 +47,7 @@ def __init__( | |
model_id (int): The ID of the model. | ||
version_id (int): The version of the model. | ||
contracts (Dict[str, str]): The contracts to handle, must be a dictionary with the contract name as the key and the contract address as the value. | ||
integrations (List[str]): The integrations to use. | ||
chain_id (int): The ID of the blockchain network. | ||
**kwargs: Additional keyword arguments. | ||
""" | ||
|
@@ -63,11 +67,11 @@ def __init__( | |
logger.error("Agent is missing required parameters") | ||
raise ValueError(f"Agent is missing required parameters: {e}") | ||
|
||
self.contract_handler = ContractHandler(contracts) | ||
self.chain = chain | ||
self.account = account | ||
self._check_passphrase_in_env() | ||
self._check_or_create_account() | ||
self.contract_handler = ContractHandler(contracts, integrations) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Add an |
||
|
||
# Useful for testing | ||
network_parser: Callable = kwargs.get( | ||
|
@@ -240,8 +244,8 @@ def execute(self) -> Any: | |
f"Invalid passphrase for account {self.account}. Could not decrypt account." | ||
) from e | ||
logger.debug("Autosign enabled") | ||
with accounts.use_sender(self._account): | ||
yield self.contract_handler.handle() | ||
with accounts.use_sender(self._account) as sender: | ||
yield self.contract_handler.handle(account=sender) | ||
|
||
def predict( | ||
self, | ||
|
@@ -452,15 +456,35 @@ class ContractHandler: | |
which means that it should be done insede the GizaAgent's execute context. | ||
""" | ||
|
||
def __init__(self, contracts: Dict[str, Union[str, List[str]]]) -> None: | ||
def __init__( | ||
self, | ||
contracts: Optional[Dict[str, Union[str, List[str]]]] = None, | ||
integrations: Optional[List[str]] = None, | ||
) -> None: | ||
if contracts is None: | ||
contracts = {} | ||
if integrations is None: | ||
integrations = [] | ||
contract_names = list(contracts.keys()) | ||
duplicates = set(contract_names) & set(integrations) | ||
if duplicates: | ||
duplicate_names = ", ".join(duplicates) | ||
raise ValueError( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. For future reference, lets create a custom exception for this |
||
f"Integrations of these names already exist: {duplicate_names}. Choose different contract names." | ||
) | ||
self._contracts = contracts | ||
self._integrations = integrations | ||
self._contracts_instances: Dict[str, ContractInstance] = {} | ||
self._integrations_instances: Dict[str, IntegrationFactory] = {} | ||
|
||
def __getattr__(self, name: str) -> ContractInstance: | ||
def __getattr__(self, name: str) -> Union[ContractInstance, IntegrationFactory]: | ||
""" | ||
Get the contract by name. | ||
""" | ||
return self._contracts_instances[name] | ||
if name in self._contracts_instances.keys(): | ||
return self._contracts_instances[name] | ||
if name in self._integrations_instances.keys(): | ||
return self._integrations_instances[name] | ||
|
||
def _initiate_contract( | ||
self, address: str, abi: Optional[str] = None | ||
|
@@ -472,26 +496,41 @@ def _initiate_contract( | |
return Contract(address=address) | ||
return Contract(address=address, abi=abi) | ||
|
||
def handle(self) -> Self: | ||
def _initiate_integration( | ||
self, name: str, account: AccountAPI | ||
) -> IntegrationFactory: | ||
""" | ||
Initiate the integration. | ||
""" | ||
return IntegrationFactory.from_name(name, sender=account) | ||
|
||
def handle(self, account: Optional[AccountAPI] = None) -> Self: | ||
""" | ||
Handle the contracts. | ||
""" | ||
try: | ||
for name, contract_data in self._contracts.items(): | ||
if isinstance(contract_data, str): | ||
address = contract_data | ||
self._contracts_instances[name] = self._initiate_contract(address) | ||
elif isinstance(contract_data, list): | ||
if len(contract_data) == 1: | ||
address = contract_data[0] | ||
if self._contracts: | ||
for name, contract_data in self._contracts.items(): | ||
if isinstance(contract_data, str): | ||
address = contract_data | ||
self._contracts_instances[name] = self._initiate_contract( | ||
address | ||
) | ||
else: | ||
address, abi = contract_data | ||
self._contracts_instances[name] = self._initiate_contract( | ||
address, abi | ||
) | ||
elif isinstance(contract_data, list): | ||
if len(contract_data) == 1: | ||
address = contract_data[0] | ||
self._contracts_instances[name] = self._initiate_contract( | ||
address | ||
) | ||
else: | ||
address, abi = contract_data | ||
self._contracts_instances[name] = self._initiate_contract( | ||
address, abi | ||
) | ||
for name in self._integrations: | ||
self._integrations_instances[name] = self._initiate_integration( | ||
name, account | ||
) | ||
except NetworkError as e: | ||
logger.error(f"Failed to initiate contract: {e}") | ||
raise ValueError( | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
from ape.api import AccountAPI | ||
|
||
import giza.agents.integrations.uniswap.uniswap as uniswap_module | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Create an
from import giza.agents.integrations.uniswap.uniswap import Uniswap
__all__ = ["Uniswap"] New import: from giza.agents.integrations import Uniswap Another suggestion could be to call all the integration where the main integration is with a common name like from giza.agents.integrations.uniswap.core import UniswapV3
from giza.agents.integrations.enzyme.core import Enzyme |
||
|
||
|
||
class IntegrationFactory: | ||
@staticmethod | ||
def from_name(name: str, sender: AccountAPI) -> uniswap_module.Uniswap: | ||
match name: | ||
case "UniswapV3": | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think that it would be better to have an enum for this, maybe open an issue and will deal with this later |
||
return uniswap_module.Uniswap(sender, version=3) | ||
case _: | ||
raise ValueError(f"Integration {name} not found") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Remove the URL as it might be a private RPC node