Skip to content

Commit

Permalink
fix pytorch extractor tests
Browse files Browse the repository at this point in the history
  • Loading branch information
mastoffel committed Feb 7, 2025
1 parent c0c85aa commit 4ea7446
Showing 1 changed file with 8 additions and 8 deletions.
16 changes: 8 additions & 8 deletions tests/test_pytorch_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from autoemulate.emulators import GaussianProcess
from autoemulate.emulators import RandomForest
from autoemulate.utils import _extract_pytorch_model
from autoemulate.utils import extract_pytorch_model


@pytest.fixture
Expand All @@ -35,42 +35,42 @@ def Xy():
# test_error_when_not_fitted
def test_error_when_not_fitted(pytorch_model):
with pytest.raises(ValueError):
_extract_pytorch_model(pytorch_model)
extract_pytorch_model(pytorch_model)


# test standalone model
def test_extract_when_fitted(pytorch_model, Xy):
pytorch_model.fit(*Xy)
model = _extract_pytorch_model(pytorch_model)
model = extract_pytorch_model(pytorch_model)
assert isinstance(model, torch.nn.Module)


def test_error_when_not_pytorch_model(non_pytorch_model, Xy):
non_pytorch_model.fit(*Xy)
with pytest.raises(ValueError):
_extract_pytorch_model(non_pytorch_model)
extract_pytorch_model(non_pytorch_model)


def test_error_when_multiout_model(non_pytorch_multiout_model, Xy):
non_pytorch_multiout_model.fit(*Xy)
with pytest.raises(ValueError):
_extract_pytorch_model(non_pytorch_multiout_model)
extract_pytorch_model(non_pytorch_multiout_model)


# test pipeline
def test_extract_when_fitted_pipeline(pytorch_model, Xy):
pytorch_model.fit(*Xy)
model = _extract_pytorch_model(pytorch_model)
model = extract_pytorch_model(pytorch_model)
assert isinstance(model, torch.nn.Module)


def test_error_when_non_pytorch_pipeline(non_pytorch_model, Xy):
non_pytorch_model.fit(*Xy)
with pytest.raises(ValueError):
_extract_pytorch_model(non_pytorch_model)
extract_pytorch_model(non_pytorch_model)


def test_error_when_multiout_pipeline(non_pytorch_multiout_model, Xy):
non_pytorch_multiout_model.fit(*Xy)
with pytest.raises(ValueError):
_extract_pytorch_model(non_pytorch_multiout_model)
extract_pytorch_model(non_pytorch_multiout_model)

0 comments on commit 4ea7446

Please sign in to comment.