Skip to content

Commit

Permalink
contract caller tests added and passing
Browse files Browse the repository at this point in the history
  • Loading branch information
pacrob committed Feb 8, 2023
1 parent 4872693 commit 5b4c87f
Show file tree
Hide file tree
Showing 5 changed files with 114 additions and 35 deletions.
87 changes: 80 additions & 7 deletions tests/core/contracts/test_contract_caller_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ def test_caller_with_args_and_no_transaction_keyword(

add_result = contract.add(3, 5)
assert add_result == 8


@pytest.mark.parametrize(
"method_input, expected, type_str, namedtuple_repr",
Expand Down Expand Up @@ -207,12 +207,13 @@ def test_caller_with_args_and_no_transaction_keyword(
),
)
def test_tuple_contract_caller_with_decode_tuples(
tuple_contract_with_decode_tuples, method_input, expected, type_str, namedtuple_repr, transaction_dict
tuple_contract_with_decode_tuples,
method_input,
expected,
type_str,
namedtuple_repr,
):
# result = tuple_contract_with_decode_tuples.caller(transaction=transaction_dict).method(method_input)
result = tuple_contract_with_decode_tuples.caller(decode_tuples=True).method(method_input)
caller = tuple_contract_with_decode_tuples.caller
breakpoint()
result = tuple_contract_with_decode_tuples.caller.method(method_input)
assert result == expected
assert str(type(result)) == type_str
assert result.__repr__() == namedtuple_repr
Expand All @@ -221,7 +222,6 @@ def test_tuple_contract_caller_with_decode_tuples(
assert str(type(result)) == type_str
assert result.__repr__() == namedtuple_repr


# --- async --- #


Expand Down Expand Up @@ -369,3 +369,76 @@ async def test_async_caller_with_args_and_no_transaction_keyword(

add_result = await contract.add(3, 5)
assert add_result == 8


@pytest.mark.parametrize(
"method_input, expected, type_str, namedtuple_repr",
(
(
{
"a": 123,
"b": [1, 2],
"c": [
{
"x": 234,
"y": [True, False],
"z": [
"0x4AD7E79d88650B01EEA2B1f069f01EE9db343d5c",
"0xfdF1946A9b40245224488F1a36f4A9ed4844a523",
"0xfdF1946A9b40245224488F1a36f4A9ed4844a523",
],
},
{
"x": 345,
"y": [False, False],
"z": [
"0xefd1FF70c185A1C0b125939815225199079096Ee",
"0xf35C0784794F3Cd935F5754d3a0EbcE95bEf851e",
],
},
],
},
(
123,
[1, 2],
[
(
234,
[True, False],
[
"0x4AD7E79d88650B01EEA2B1f069f01EE9db343d5c",
"0xfdF1946A9b40245224488F1a36f4A9ed4844a523",
"0xfdF1946A9b40245224488F1a36f4A9ed4844a523",
],
),
(
345,
[False, False],
[
"0xefd1FF70c185A1C0b125939815225199079096Ee",
"0xf35C0784794F3Cd935F5754d3a0EbcE95bEf851e",
],
),
],
),
"<class 'web3._utils.abi.abi_decoded_namedtuple_factory.<locals>.ABIDecodedNamedTuple'>", # noqa: E501
"ABIDecodedNamedTuple(a=123, b=[1, 2], c=[ABIDecodedNamedTuple(x=234, y=[True, False], z=['0x4AD7E79d88650B01EEA2B1f069f01EE9db343d5c', '0xfdF1946A9b40245224488F1a36f4A9ed4844a523', '0xfdF1946A9b40245224488F1a36f4A9ed4844a523']), ABIDecodedNamedTuple(x=345, y=[False, False], z=['0xefd1FF70c185A1C0b125939815225199079096Ee', '0xf35C0784794F3Cd935F5754d3a0EbcE95bEf851e'])])", # noqa: E501
),
),
)
@pytest.mark.asyncio
async def test_async_tuple_contract_caller_with_decode_tuples(
async_tuple_contract_with_decode_tuples,
method_input,
expected,
type_str,
namedtuple_repr,
):
result = await async_tuple_contract_with_decode_tuples.caller.method(method_input)
assert result == expected
assert str(type(result)) == type_str
assert result.__repr__() == namedtuple_repr
result = await async_tuple_contract_with_decode_tuples.caller().method(method_input)
assert result == expected
assert str(type(result)) == type_str
assert result.__repr__() == namedtuple_repr
8 changes: 2 additions & 6 deletions tests/core/contracts/test_contract_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,9 +83,7 @@ def test_initial_greeting(foo_contract):

def test_can_update_greeting(w3, foo_contract):
# send transaction that updates the greeting
tx_hash = foo_contract.functions.setBar(
"testing contracts is easy",
).transact(
tx_hash = foo_contract.functions.setBar("testing contracts is easy",).transact(
{
"from": w3.eth.accounts[1],
}
Expand All @@ -99,9 +97,7 @@ def test_can_update_greeting(w3, foo_contract):

def test_updating_greeting_emits_event(w3, foo_contract):
# send transaction that updates the greeting
tx_hash = foo_contract.functions.setBar(
"testing contracts is easy",
).transact(
tx_hash = foo_contract.functions.setBar("testing contracts is easy",).transact(
{
"from": w3.eth.accounts[1],
}
Expand Down
17 changes: 12 additions & 5 deletions web3/contract/async_contract.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,9 +117,11 @@ def __init__(self, address: Optional[ChecksumAddress] = None) -> None:
"The address argument is required to instantiate a contract."
)
self.functions = AsyncContractFunctions(
self.abi, self.w3, self.address, self.decode_tuples
self.abi, self.w3, self.address, decode_tuples=self.decode_tuples
)
self.caller = AsyncContractCaller(
self.abi, self.w3, self.address, decode_tuples=self.decode_tuples
)
self.caller = AsyncContractCaller(self.abi, self.w3, self.address)
self.events = AsyncContractEvents(self.abi, self.w3, self.address)
self.fallback = AsyncContract.get_fallback_function(
self.abi, self.w3, AsyncContractFunction, self.address
Expand Down Expand Up @@ -151,10 +153,13 @@ def factory(
),
)
contract.functions = AsyncContractFunctions(
contract.abi, contract.w3, contract.decode_tuples
contract.abi, contract.w3, decode_tuples=contract.decode_tuples
)
contract.caller = AsyncContractCaller(
contract.abi, contract.w3, contract.address
contract.abi,
contract.w3,
contract.address,
decode_tuples=contract.decode_tuples,
)
contract.events = AsyncContractEvents(contract.abi, contract.w3)
contract.fallback = AsyncContract.get_fallback_function(
Expand Down Expand Up @@ -254,6 +259,7 @@ async def call(
block_identifier: BlockIdentifier = "latest",
state_override: Optional[CallOverride] = None,
ccip_read_enabled: Optional[bool] = None,
decode_tuples: Optional[bool] = None,
) -> Any:
"""
Execute a contract function call using the `eth_call` interface.
Expand Down Expand Up @@ -294,7 +300,7 @@ async def call(
self.abi,
state_override,
ccip_read_enabled,
self.decode_tuples,
decode_tuples,
*self.args,
**self.kwargs,
)
Expand Down Expand Up @@ -502,4 +508,5 @@ def __call__(
transaction=transaction,
block_identifier=block_identifier,
ccip_read_enabled=ccip_read_enabled,
decode_tuples=self.decode_tuples,
)
6 changes: 5 additions & 1 deletion web3/contract/base_contract.py
Original file line number Diff line number Diff line change
Expand Up @@ -691,6 +691,7 @@ def call(
block_identifier: BlockIdentifier = "latest",
state_override: Optional[CallOverride] = None,
ccip_read_enabled: Optional[bool] = None,
decode_tuples: Optional[bool] = False,
) -> Any:
# This was needed for typing
raise NotImplementedError(
Expand Down Expand Up @@ -1092,7 +1093,6 @@ def __init__(
ccip_read_enabled: Optional[bool] = None,
decode_tuples: Optional[bool] = False,
) -> None:
print(decode_tuples)
self.w3 = w3
self.address = address
self.abi = abi
Expand All @@ -1111,6 +1111,7 @@ def __init__(
contract_abi=self.abi,
address=self.address,
function_identifier=func["name"],
decode_tuples=decode_tuples,
)

block_id = parse_block_identifier(self.w3, block_identifier)
Expand All @@ -1120,6 +1121,7 @@ def __init__(
transaction=transaction,
block_identifier=block_id,
ccip_read_enabled=ccip_read_enabled,
decode_tuples=decode_tuples,
)

setattr(self, func["name"], caller_method)
Expand Down Expand Up @@ -1158,6 +1160,7 @@ def call_function(
transaction: Optional[TxParams] = None,
block_identifier: BlockIdentifier = "latest",
ccip_read_enabled: Optional[bool] = None,
decode_tuples: Optional[bool] = False,
**kwargs: Any,
) -> Any:
if transaction is None:
Expand All @@ -1166,4 +1169,5 @@ def call_function(
transaction=transaction,
block_identifier=block_identifier,
ccip_read_enabled=ccip_read_enabled,
decode_tuples=decode_tuples,
)
31 changes: 15 additions & 16 deletions web3/contract/contract.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,9 +233,11 @@ def __init__(self, address: Optional[ChecksumAddress] = None) -> None:
)

self.functions = ContractFunctions(
self.abi, _w3, self.address, self.decode_tuples
self.abi, _w3, self.address, decode_tuples=self.decode_tuples
)
self.caller = ContractCaller(
self.abi, _w3, self.address, decode_tuples=self.decode_tuples
)
self.caller = ContractCaller(self.abi, _w3, self.address, self.decode_tuples)
self.events = ContractEvents(self.abi, _w3, self.address)
self.fallback = Contract.get_fallback_function(
self.abi,
Expand Down Expand Up @@ -272,17 +274,15 @@ def factory(
normalizers=normalizers,
),
)
# if contract.decode_tuples:
# breakpoint()
# else:
# print("passsing!", contract.all_functions())

contract.functions = ContractFunctions(
contract.abi, contract.w3, contract.decode_tuples
contract.abi, contract.w3, decode_tuples=contract.decode_tuples
)
contract.caller = ContractCaller(
contract.abi,
contract.w3,
contract.address,
decode_tuples=contract.decode_tuples,
)
contract.caller = ContractCaller(contract.abi, contract.w3, contract.address, contract.decode_tuples)
# if contract.decode_tuples:
# breakpoint()
contract.events = ContractEvents(contract.abi, contract.w3)
contract.fallback = Contract.get_fallback_function(
contract.abi,
Expand Down Expand Up @@ -376,6 +376,7 @@ def call(
block_identifier: BlockIdentifier = "latest",
state_override: Optional[CallOverride] = None,
ccip_read_enabled: Optional[bool] = None,
decode_tuples: Optional[bool] = False,
) -> Any:
"""
Execute a contract function call using the `eth_call` interface.
Expand Down Expand Up @@ -416,7 +417,7 @@ def call(
self.abi,
state_override,
ccip_read_enabled,
self.decode_tuples,
decode_tuples,
*self.args,
**self.kwargs,
)
Expand Down Expand Up @@ -477,7 +478,6 @@ def __init__(
ccip_read_enabled: Optional[bool] = None,
decode_tuples: Optional[bool] = False,
) -> None:
# breakpoint()
super().__init__(
abi=abi,
w3=w3,
Expand All @@ -493,18 +493,17 @@ def __call__(
self,
transaction: Optional[TxParams] = None,
block_identifier: BlockIdentifier = "latest",
state_override: Optional[CallOverride] = None,
ccip_read_enabled: Optional[bool] = None,
decode_tuples: Optional[bool] = False,
) -> "ContractCaller":
if transaction is None:
transaction = {}

return type(self)(
self.abi,
self.w3,
self.address,
transaction=transaction,
block_identifier=block_identifier,
ccip_read_enabled=ccip_read_enabled,
decode_tuples=decode_tuples,
decode_tuples=self.decode_tuples,
)

0 comments on commit 5b4c87f

Please sign in to comment.