Skip to content

Commit

Permalink
torch: Check status of torch numpy bindings instead of specific version
Browse files Browse the repository at this point in the history
Assert that torch is usable if installed during testing.

Signed-off-by: Jan Vesely <[email protected]>
  • Loading branch information
jvesely committed May 9, 2023
1 parent c345554 commit 3f86976
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 24 deletions.
20 changes: 7 additions & 13 deletions conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,23 +7,17 @@
import re
import sys

from psyneulink import clear_registry, primary_registries
from psyneulink import clear_registry, primary_registries, torch_available
from psyneulink.core import llvm as pnlvm
from psyneulink.core.globals.utilities import set_global_seed


try:
import torch

# If we are on windows and using Python 3.10, despite it importing correctly, PyTorch is currently broken,
# see https://pytorch.org/get-started/locally/ showing lack of support.
if sys.platform.startswith("win32") and sys.version_info >= (3, 10):
pytorch_available = False
else:
pytorch_available = True

except ImportError:
pytorch_available = False
pass
else:
# Check that torch is usable if installed
assert torch_available, "Torch module is available, but not usable by PNL"

# def pytest_addoption(parser):
# parser.addoption(
Expand Down Expand Up @@ -57,7 +51,7 @@ def pytest_runtest_setup(item):
if 'cuda' in item.keywords and not pnlvm.ptx_enabled:
pytest.skip('PTX engine not enabled/available')

if 'pytorch' in item.keywords and not pytorch_available:
if 'pytorch' in item.keywords and not torch_available:
pytest.skip('pytorch not available')

doctest.ELLIPSIS_MARKER = "[...]"
Expand Down Expand Up @@ -122,7 +116,7 @@ def pytest_runtest_call(item):
set_global_seed(seed)

if 'pytorch' in item.keywords:
assert pytorch_available
assert torch_available
torch.manual_seed(seed)


Expand Down
14 changes: 14 additions & 0 deletions psyneulink/library/compositions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,17 @@
__all__ = list(regressioncfa.__all__)
__all__.extend(compositionrunner.__all__)
__all__.extend(autodiffcomposition.__all__)

try:
import torch
from torch import nn

# Some torch releases have silent dependency on a more recent numpy than the one curently required by PNL.
# This breaks torch numpy bindings, see e.g: https://github.com/pytorch/pytorch/issues/100690
torch.tensor([1,2,3]).numpy()

torch_available = True
except (ImportError, RuntimeError):
torch_available = False

__all__.append('torch_available')
13 changes: 2 additions & 11 deletions psyneulink/library/compositions/pytorchmodelcreator.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import graph_scheduler
import torch
import torch.nn as nn

from psyneulink.core.components.component import Component, ComponentsMeta
from psyneulink.core.compositions.composition import NodeRole
Expand All @@ -11,23 +13,12 @@
from psyneulink.core.globals.utilities import get_deepcopy_with_shared
from .pytorchcomponents import *

try:
import torch
from torch import nn
torch_available = True
except ImportError:
torch_available = False

__all__ = ['PytorchModelCreator']

class PytorchModelCreator(torch.nn.Module):
# sets up parameters of model & the information required for forward computation
def __init__(self, composition, device, context=None):

if not torch_available:
raise Exception('Pytorch python module (torch) is not installed. Please install it with '
'`pip install torch` or `pip3 install torch`')

super(PytorchModelCreator, self).__init__()

# Maps Mechanism -> PytorchMechanismWrapper
Expand Down

0 comments on commit 3f86976

Please sign in to comment.