Fix mypy checks and jax installation in CI pipeline #6945

Merged · 7 commits · Oct 11, 2023
10 changes: 3 additions & 7 deletions .github/workflows/tests.yml
@@ -358,12 +358,12 @@ jobs:
- name: Cache conda
uses: actions/cache@v3
env:
# Increase this value to reset cache if environment-test.yml has not changed
# Increase this value to reset cache if environment-jax.yml has not changed
CACHE_NUMBER: 0
with:
path: ~/conda_pkgs_dir
key: ${{ runner.os }}-py${{matrix.python-version}}-conda-${{ env.CACHE_NUMBER }}-${{
hashFiles('conda-envs/environment-test.yml') }}
hashFiles('conda-envs/environment-jax.yml') }}
- name: Cache multiple paths
uses: actions/cache@v3
env:
@@ -383,7 +383,7 @@
mamba-version: "*"
activate-environment: pymc-test
channel-priority: strict
environment-file: conda-envs/environment-test.yml
environment-file: conda-envs/environment-jax.yml
python-version: ${{matrix.python-version}}
use-mamba: true
use-only-tar-bz2: false # IMPORTANT: This may break caching of conda packages! See https://github.com/conda-incubator/setup-miniconda/issues/267
@@ -392,10 +392,6 @@
conda activate pymc-test
pip install -e .
python --version
- name: Install external samplers
run: |
conda activate pymc-test
pip install "numpyro>=0.8.0" "blackjax>=1.0.0"
- name: Run tests
run: |
python -m pytest -vv --cov=pymc --cov-report=xml --no-cov-on-fail --cov-report term --durations=50 $TEST_SUBSET
2 changes: 1 addition & 1 deletion conda-envs/environment-dev.yml
@@ -36,7 +36,7 @@ dependencies:
- watermark
- polyagamma
- sphinx-remove-toctrees
- mypy=0.990
- mypy=1.5.1
- types-cachetools
- pip:
- git+https://github.com/pymc-devs/pymc-sphinx-theme
38 changes: 38 additions & 0 deletions conda-envs/environment-jax.yml
@@ -0,0 +1,38 @@
# "test" conda envs are used to set up our CI environment in GitHub actions
name: pymc-test
channels:
- conda-forge
- defaults
dependencies:
# Base dependencies
- arviz>=0.13.0
- blas
- cachetools>=4.2.1
- cloudpickle
- fastprogress>=0.2.0
- h5py>=2.7
# Jaxlib version must not be greater than jax version!
- blackjax>=1.0.0
- jaxlib==0.4.14
- jax==0.4.16
- libblas=*=*mkl
- mkl-service
- numpy>=1.15.0
- numpyro>=0.8.0
- pandas>=0.24.0
- pip
- pytensor>=2.17.0,<2.18
- python-graphviz
- networkx
- scipy>=1.4.1
- typing-extensions>=3.7.4
# Extra dependencies for testing
- ipython>=7.16
- pre-commit>=2.8.0
- pytest-cov>=2.5
- pytest>=3.0
- mypy=1.5.1
- types-cachetools
- pip:
- numdifftools>=0.9.40
- mcbackend>=0.4.0
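
A side note on the jax/jaxlib pins above: the file's own comment says jaxlib must not be newer than jax. Below is a minimal sketch of how that constraint could be checked at runtime (illustrative only, not part of this PR; the use of the packaging library here is my assumption):

import jax
import jaxlib
from packaging.version import Version

# Fail loudly if the installed jaxlib is newer than jax, per the comment in environment-jax.yml.
assert Version(jaxlib.__version__) <= Version(jax.__version__), (
    f"jaxlib {jaxlib.__version__} is newer than jax {jax.__version__}"
)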
2 changes: 1 addition & 1 deletion conda-envs/environment-test.yml
@@ -27,7 +27,7 @@ dependencies:
- pre-commit>=2.8.0
- pytest-cov>=2.5
- pytest>=3.0
- mypy=0.990
- mypy=1.5.1
- types-cachetools
- pip:
- numdifftools>=0.9.40
2 changes: 1 addition & 1 deletion conda-envs/windows-environment-dev.yml
@@ -33,7 +33,7 @@ dependencies:
- sphinx>=1.5
- watermark
- sphinx-remove-toctrees
- mypy=0.990
- mypy=1.5.1
- types-cachetools
- pip:
- git+https://github.com/pymc-devs/pymc-sphinx-theme
2 changes: 1 addition & 1 deletion conda-envs/windows-environment-test.yml
@@ -27,7 +27,7 @@ dependencies:
- pre-commit>=2.8.0
- pytest-cov>=2.5
- pytest>=3.0
- mypy=0.990
- mypy=1.5.1
- types-cachetools
- pip:
- numdifftools>=0.9.40
12 changes: 6 additions & 6 deletions pymc/gp/hsgp_approx.py
@@ -189,7 +189,9 @@ def __init__(
self._drop_first = drop_first
self._m = m
self._m_star = int(np.prod(self._m))
self._L = L
self._L: Optional[pt.TensorVariable] = None
if L is not None:
self._L = pt.as_tensor(L)
self._c = c

super().__init__(mean_func=mean_func, cov_func=cov_func)
@@ -198,13 +200,13 @@ def __add__(self, other):
raise NotImplementedError("Additive HSGPs aren't supported.")

@property
def L(self):
def L(self) -> pt.TensorVariable:
if self._L is None:
raise RuntimeError("Boundaries `L` required but still unset.")
return self._L

@L.setter
def L(self, value):
def L(self, value: TensorLike):
self._L = pt.as_tensor_variable(value)

def prior_linearized(self, Xs: TensorLike):
@@ -290,9 +292,7 @@ def prior_linearized(self, Xs: TensorLike):
# If not provided, use Xs and c to set L
if self._L is None:
assert isinstance(self._c, (numbers.Real, np.ndarray, pt.TensorVariable))
self.L = set_boundary(Xs, self._c)
else:
self.L = self._L
self._L = pt.as_tensor(set_boundary(Xs, self._c))

eigvals = calc_eigenvalues(self.L, self._m, tl=pt)
phi = calc_eigenvectors(Xs, self.L, eigvals, self._m, tl=pt)
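
The _L handling above declares the attribute as Optional[pt.TensorVariable] up front, converts inputs with pt.as_tensor, and guards access behind the L property. A minimal, self-contained sketch of that pattern (illustrative; the Box class below is hypothetical, not from this PR):

from typing import Optional

import numpy as np
import pytensor.tensor as pt


class Box:
    """Toy container: the attribute stays Optional until a value is known."""

    def __init__(self, L=None):
        # Annotate once so later tensor assignments type-check under mypy.
        self._L: Optional[pt.TensorVariable] = None
        if L is not None:
            self._L = pt.as_tensor(L)

    @property
    def L(self) -> pt.TensorVariable:
        if self._L is None:
            raise RuntimeError("Boundary L required but still unset.")
        return self._L


b = Box(np.array([1.0, 2.0]))
print(b.L.eval())  # [1. 2.]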
5 changes: 3 additions & 2 deletions pymc/model/core.py
@@ -66,7 +66,7 @@
from pymc.initial_point import make_initial_point_fn
from pymc.logprob.basic import transformed_conditional_logp
from pymc.logprob.utils import ParameterValueError
from pymc.model_graph import VarName, model_to_graphviz
from pymc.model_graph import model_to_graphviz
from pymc.pytensorf import (
PointFunc,
SeedSequenceSeed,
@@ -80,6 +80,7 @@
)
from pymc.util import (
UNSET,
VarName,
WithMemoization,
_add_future_warning_tag,
get_transformed_name,
@@ -2061,7 +2062,7 @@ def compile_fn(
)


def Point(*args, filter_model_vars=False, **kwargs) -> Dict[str, np.ndarray]:
def Point(*args, filter_model_vars=False, **kwargs) -> Dict[VarName, np.ndarray]:
"""Build a point. Uses same args as dict() does.
Filters out variables not in the model. All keys are strings.

8 changes: 5 additions & 3 deletions pymc/model/transform/basic.py
@@ -54,6 +54,8 @@ def prune_vars_detached_from_observed(model: Model) -> Model:


def parse_vars(model: Model, vars: Union[ModelVariable, Sequence[ModelVariable]]) -> List[Variable]:
if not isinstance(vars, (list, tuple)):
vars = (vars,)
return [model[var] if isinstance(var, str) else var for var in vars]
if isinstance(vars, (list, tuple)):
vars_seq = vars
else:
vars_seq = (vars,)
return [model[var] if isinstance(var, str) else var for var in vars_seq]
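
The rewrite above binds the normalized sequence to a fresh name (vars_seq) instead of reassigning the vars parameter, which gives mypy a single consistent type to check. A small illustration of the same idea on a hypothetical function (not from this PR):

from typing import List, Sequence, Union


def normalize(items: Union[str, Sequence[str]]) -> List[str]:
    # A fresh, consistently typed name avoids re-typing the parameter in place.
    items_seq: Sequence[str] = (items,) if isinstance(items, str) else items
    return list(items_seq)


print(normalize("a"))         # ['a']
print(normalize(["a", "b"]))  # ['a', 'b']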
19 changes: 8 additions & 11 deletions pymc/model_graph.py
@@ -14,7 +14,7 @@
import warnings

from collections import defaultdict
from typing import Dict, Iterable, List, NewType, Optional, Sequence, Set
from typing import Dict, Iterable, List, Optional, Sequence, Set

from pytensor import function
from pytensor.compile.sharedvalue import SharedVariable
@@ -28,10 +28,7 @@

import pymc as pm

from pymc.util import get_default_varnames, get_var_name

VarName = NewType("VarName", str)

from pymc.util import VarName, get_default_varnames, get_var_name

__all__ = (
"ModelGraph",
@@ -76,12 +73,12 @@ def _expand(x):
return reversed(_filter_non_parameter_inputs(x))
return []

parents = {
VarName(get_var_name(x))
for x in walk(nodes=_filter_non_parameter_inputs(var), expand=_expand)
parents = set()
for x in walk(nodes=_filter_non_parameter_inputs(var), expand=_expand):
# Only consider nodes that are in the named model variables.
if x.name and x.name in self._all_var_names
}
vname = getattr(x, "name", None)
if isinstance(vname, str) and vname in self._all_var_names:
parents.add(VarName(vname))

return parents

@@ -113,7 +110,7 @@ def vars_to_plot(self, var_names: Optional[Iterable[VarName]] = None) -> List[Va
selected_ancestors.add(self.model.rvs_to_values[var])

# ordering of self._all_var_names is important
return [VarName(var.name) for var in selected_ancestors]
return [get_var_name(var) for var in selected_ancestors]

def make_compute_graph(
self, var_names: Optional[Iterable[VarName]] = None
4 changes: 2 additions & 2 deletions pymc/sampling/jax.py
@@ -336,7 +336,7 @@ def sample_blackjax_nuts(
var_names: Optional[Sequence[str]] = None,
keep_untransformed: bool = False,
chain_method: str = "parallel",
postprocessing_backend: Literal["cpu", "gpu"] | None = None,
postprocessing_backend: Optional[Literal["cpu", "gpu"]] = None,
postprocessing_vectorize: Literal["vmap", "scan"] = "scan",
idata_kwargs: Optional[Dict[str, Any]] = None,
postprocessing_chunks=None, # deprecated
@@ -546,7 +546,7 @@ def sample_numpyro_nuts(
progressbar: bool = True,
keep_untransformed: bool = False,
chain_method: str = "parallel",
postprocessing_backend: Literal["cpu", "gpu"] | None = None,
postprocessing_backend: Optional[Literal["cpu", "gpu"]] = None,
postprocessing_vectorize: Literal["vmap", "scan"] = "scan",
idata_kwargs: Optional[Dict] = None,
nuts_kwargs: Optional[Dict] = None,
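
The two signatures above swap the "X | None" union syntax for typing.Optional; as I understand it (an assumption about the motivation), this keeps the annotations evaluable at runtime on Python versions older than 3.10 without from __future__ import annotations. A brief sketch of the equivalence:

from typing import Literal, Optional


# Evaluates fine at import time on Python 3.8+:
def run(postprocessing_backend: Optional[Literal["cpu", "gpu"]] = None) -> str:
    return postprocessing_backend or "cpu"


# The equivalent annotation Literal["cpu", "gpu"] | None needs Python 3.10+
# (or from __future__ import annotations at the top of the module).
print(run())       # cpu
print(run("gpu"))  # gpu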
8 changes: 5 additions & 3 deletions pymc/util.py
@@ -15,7 +15,7 @@
import functools
import warnings

from typing import Any, Dict, List, Optional, Sequence, Tuple, Union, cast
from typing import Any, Dict, List, NewType, Optional, Sequence, Tuple, Union, cast

import arviz
import cloudpickle
@@ -29,6 +29,8 @@

from pymc.exceptions import BlockModelAccessError

VarName = NewType("VarName", str)


class _UnsetType:
"""Type for the `UNSET` object to make it look nice in `help(...)` outputs."""
@@ -207,9 +209,9 @@ def get_default_varnames(var_iterator, include_transformed):
return [var for var in var_iterator if not is_transformed_name(get_var_name(var))]


def get_var_name(var) -> str:
def get_var_name(var) -> VarName:
"""Get an appropriate, plain variable name for a variable."""
return str(getattr(var, "name", var))
return VarName(str(getattr(var, "name", var)))


def get_transformed(z):
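
With VarName now defined in pymc.util as a NewType over str, get_var_name can return a type that mypy tracks separately from plain strings while costing nothing at runtime. A minimal sketch of how NewType behaves (the dictionary below is hypothetical):

from typing import Dict, NewType

VarName = NewType("VarName", str)


def get_var_name(var) -> VarName:
    # Plain str at runtime, distinct VarName for the type checker.
    return VarName(str(getattr(var, "name", var)))


point: Dict[VarName, float] = {}
point[get_var_name("mu")] = 0.0  # OK: the key went through the VarName constructor
# point["mu"] = 0.0              # mypy would reject a bare str key here
print(point)  # {'mu': 0.0}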
24 changes: 18 additions & 6 deletions pymc/variational/opvi.py
@@ -51,7 +51,7 @@
import itertools
import warnings

from typing import Any
from typing import Any, overload

import numpy as np
import pytensor
@@ -980,17 +980,29 @@ def symbolic_random(self):
"""
raise NotImplementedError

@pytensor.config.change_flags(compute_test_value="off")
@overload
def set_size_and_deterministic(
self, node: Variable, s, d: bool, more_replacements: dict | None = None
) -> Variable:
...

@overload
def set_size_and_deterministic(
self, node: Variable, s, d: bool, more_replacements: dict = None
self, node: list[Variable], s, d: bool, more_replacements: dict | None = None
) -> list[Variable]:
...

@pytensor.config.change_flags(compute_test_value="off")
def set_size_and_deterministic(
self, node: Variable | list[Variable], s, d: bool, more_replacements: dict | None = None
) -> Variable | list[Variable]:
"""*Dev* - after node is sampled via :func:`symbolic_sample_over_posterior` or
:func:`symbolic_single_sample` new random generator can be allocated and applied to node

Parameters
----------
node: :class:`Variable`
PyTensor node with symbolically applied VI replacements
node
PyTensor node(s) with symbolically applied VI replacements
s: scalar
desired number of samples
d: bool or int
@@ -1000,7 +1012,7 @@ def set_size_and_deterministic(

Returns
-------
:class:`Variable` with applied replacements, ready to use
:class:`Variable` or list with applied replacements, ready to use
"""

flat2rand = self.make_size_and_deterministic_replacements(s, d, more_replacements)
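
The @overload stubs above give mypy a precise mapping: a single Variable in yields a Variable out, a list in yields a list out, while the one decorated implementation handles both. A compact sketch of the same pattern on a hypothetical function:

from typing import List, Union, overload


@overload
def double(x: int) -> int: ...
@overload
def double(x: List[int]) -> List[int]: ...


def double(x: Union[int, List[int]]) -> Union[int, List[int]]:
    # One implementation; the overload stubs exist only for the type checker.
    if isinstance(x, list):
        return [2 * v for v in x]
    return 2 * x


print(double(3))          # 6          (mypy infers int)
print(double([1, 2, 3]))  # [2, 4, 6]  (mypy infers List[int])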
2 changes: 1 addition & 1 deletion requirements-dev.txt
@@ -10,7 +10,7 @@ h5py>=2.7
ipython>=7.16
jupyter-sphinx
mcbackend>=0.4.0
mypy==0.990
mypy==1.5.1
myst-nb
numdifftools>=0.9.40
numpy>=1.15.0
1 change: 1 addition & 0 deletions scripts/generate_pip_deps_from_conda.py
@@ -54,6 +54,7 @@
"networkx",
"blas",
"jax",
"jaxlib",
}
RENAME = {}

4 changes: 0 additions & 4 deletions scripts/run_mypy.py
@@ -41,13 +41,10 @@
pymc/logprob/utils.py
pymc/model/core.py
pymc/model/fgraph.py
pymc/model/transform/basic.py
pymc/model/transform/conditioning.py
pymc/model_graph.py
pymc/printing.py
pymc/pytensorf.py
pymc/sampling/jax.py
pymc/variational/opvi.py
"""


@@ -105,7 +102,6 @@ def check_no_unexpected_results(mypy_lines: Iterator[str]):
Exits the process with non-zero exit code upon unexpected results.
"""
df = mypy_to_pandas(mypy_lines)

all_files = {
str(fp).replace(str(DP_ROOT), "").strip(os.sep).replace(os.sep, "/")
for fp in DP_ROOT.glob("pymc/**/*.py")