Skip to content

Commit

Permalink
Fix tests on main (#2739)
Browse files Browse the repository at this point in the history
* Start

* Fixings
  • Loading branch information
muellerzr authored May 3, 2024
1 parent 6ac27e2 commit 060361f
Show file tree
Hide file tree
Showing 6 changed files with 18 additions and 0 deletions.
1 change: 1 addition & 0 deletions src/accelerate/test_utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
require_single_gpu,
require_single_xpu,
require_torch_min_version,
require_torchvision,
require_tpu,
require_xpu,
skip,
Expand Down
8 changes: 8 additions & 0 deletions src/accelerate/test_utils/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
is_timm_available,
is_torch_version,
is_torch_xla_available,
is_torchvision_available,
is_transformers_available,
is_wandb_available,
is_xpu_available,
Expand Down Expand Up @@ -214,6 +215,13 @@ def require_timm(test_case):
return unittest.skipUnless(is_timm_available(), "test requires the timm library")(test_case)


def require_torchvision(test_case):
    """Decorator marking a test that requires torchvision.

    The wrapped test is skipped (not failed) when torchvision is not installed
    in the current environment.
    """
    # Build the skip decorator once, then apply it to the test case.
    skip_decorator = unittest.skipUnless(
        is_torchvision_available(), "test requires the torchvision library"
    )
    return skip_decorator(test_case)


def require_schedulefree(test_case):
"""
Decorator marking a test that requires schedulefree. These tests are skipped when they are not.
Expand Down
1 change: 1 addition & 0 deletions src/accelerate/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@
is_tensorboard_available,
is_timm_available,
is_torch_xla_available,
is_torchvision_available,
is_transformer_engine_available,
is_transformers_available,
is_wandb_available,
Expand Down
4 changes: 4 additions & 0 deletions src/accelerate/utils/imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,10 @@ def is_bnb_available():
return _is_package_available("bitsandbytes")


def is_torchvision_available():
    """Return whether the `torchvision` package is installed and importable."""
    available = _is_package_available("torchvision")
    return available


def is_megatron_lm_available():
if str_to_bool(os.environ.get("ACCELERATE_USE_MEGATRON_LM", "False")) == 1:
package_exists = importlib.util.find_spec("megatron") is not None
Expand Down
2 changes: 2 additions & 0 deletions tests/test_accelerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,8 @@ def test_prepared_objects_are_referenced(self):

def test_free_memory_dereferences_prepared_components(self):
accelerator = Accelerator()
# Free up refs with empty_cache() and gc.collect()
accelerator.free_memory()
model, optimizer, scheduler, train_dl, valid_dl = create_components()
free_cpu_ram_before = psutil.virtual_memory().available // 1024 // 1024
model, optimizer, scheduler, train_dl, valid_dl = accelerator.prepare(
Expand Down
2 changes: 2 additions & 0 deletions tests/test_multigpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
require_multi_gpu,
require_non_torch_xla,
require_pippy,
require_torchvision,
)
from accelerate.utils import patch_environment

Expand Down Expand Up @@ -76,6 +77,7 @@ def test_distributed_data_loop(self):

@require_multi_gpu
@require_pippy
@require_torchvision
@require_huggingface_suite
def test_pippy(self):
"""
Expand Down

0 comments on commit 060361f

Please sign in to comment.