diff --git a/tests/conftest.py b/tests/conftest.py index e673f17b35..6eb34a3e0a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -15,6 +15,7 @@ from web3.contract import Contract from web3.providers.eth_tester import EthereumTesterProvider +from tests.utils import working_directory from vyper import compiler from vyper.ast.grammar import parse_vyper_source from vyper.codegen.ir_node import IRnode @@ -79,6 +80,12 @@ def debug(pytestconfig): _set_debug_mode(debug) +@pytest.fixture +def chdir_tmp_path(tmp_path): + with working_directory(tmp_path): + yield + + @pytest.fixture def keccak(): return Web3.keccak diff --git a/tests/functional/syntax/modules/test_initializers.py b/tests/functional/syntax/modules/test_initializers.py index a12f5f57ea..d0523153c8 100644 --- a/tests/functional/syntax/modules/test_initializers.py +++ b/tests/functional/syntax/modules/test_initializers.py @@ -326,7 +326,7 @@ def foo(): assert e.value._hint == "did you mean `m := lib1`?" -def test_global_initializer_constraint(make_input_bundle): +def test_global_initializer_constraint(make_input_bundle, chdir_tmp_path): lib1 = """ counter: uint256 """ @@ -818,7 +818,7 @@ def foo(new_value: uint256): assert e.value._hint == expected_hint -def test_invalid_uses(make_input_bundle): +def test_invalid_uses(make_input_bundle, chdir_tmp_path): lib1 = """ counter: uint256 """ @@ -848,7 +848,7 @@ def foo(): assert e.value._hint == "delete `uses: lib1`" -def test_invalid_uses2(make_input_bundle): +def test_invalid_uses2(make_input_bundle, chdir_tmp_path): # test a more complicated invalid uses lib1 = """ counter: uint256 diff --git a/tests/unit/cli/vyper_compile/test_compile_files.py b/tests/unit/cli/vyper_compile/test_compile_files.py index 2a65d66835..6adee24db6 100644 --- a/tests/unit/cli/vyper_compile/test_compile_files.py +++ b/tests/unit/cli/vyper_compile/test_compile_files.py @@ -1,3 +1,5 @@ +import contextlib +import sys from pathlib import Path import pytest @@ -257,3 +259,34 @@ def foo() -> uint256: contract_file = make_file("contract.vy", contract_source) assert compile_files([contract_file], ["combined_json"], paths=[tmp_path]) is not None + + +@contextlib.contextmanager +def mock_sys_path(path): + try: + sys.path.append(path) + yield + finally: + sys.path.pop() + + +def test_import_sys_path(tmp_path_factory, make_file): + library_source = """ +@internal +def foo() -> uint256: + return block.number + 1 + """ + contract_source = """ +import lib + +@external +def foo() -> uint256: + return lib.foo() + """ + tmpdir = tmp_path_factory.mktemp("test-sys-path") + with open(tmpdir / "lib.vy", "w") as f: + f.write(library_source) + + contract_file = make_file("contract.vy", contract_source) + with mock_sys_path(tmpdir): + assert compile_files([contract_file], ["combined_json"]) is not None diff --git a/vyper/cli/vyper_compile.py b/vyper/cli/vyper_compile.py index d6ba9e180a..ac69cf3310 100755 --- a/vyper/cli/vyper_compile.py +++ b/vyper/cli/vyper_compile.py @@ -238,10 +238,18 @@ def compile_files( storage_layout_paths: list[str] = None, no_bytecode_metadata: bool = False, ) -> dict: - paths = paths or [] + # lowest precedence search path is always sys path + search_paths = [Path(p) for p in sys.path] + + # python sys path uses opposite resolution order from us + # (first in list is highest precedence; we give highest precedence + # to the last in the list) + search_paths.reverse() - # lowest precedence search path is always `.` - search_paths = [Path(".")] + if Path(".") not in search_paths: + search_paths.append(Path(".")) + + paths = paths or [] for p in paths: path = Path(p).resolve(strict=True) diff --git a/vyper/semantics/analysis/module.py b/vyper/semantics/analysis/module.py index e50c3e6d6f..9d2cef6eee 100644 --- a/vyper/semantics/analysis/module.py +++ b/vyper/semantics/analysis/module.py @@ -695,10 +695,22 @@ def _load_import_helper( def _parse_and_fold_ast(file: FileInput) -> vy_ast.Module: + module_path = file.resolved_path # for error messages + try: + # try to get a relative path, to simplify the error message + cwd = Path(".") + if module_path.is_absolute(): + cwd = cwd.resolve() + module_path = module_path.relative_to(cwd) + except ValueError: + # we couldn't get a relative path (cf. docs for Path.relative_to), + # use the resolved path given to us by the InputBundle + pass + ret = vy_ast.parse_to_ast( file.source_code, source_id=file.source_id, - module_path=str(file.path), + module_path=str(module_path), resolved_path=str(file.resolved_path), ) return ret