diff --git a/.github/workflows/run_tests.yml b/.github/workflows/run_tests.yml index e63c1e7..9e979ea 100644 --- a/.github/workflows/run_tests.yml +++ b/.github/workflows/run_tests.yml @@ -9,78 +9,45 @@ on: - cron: "0 2 * * 6" jobs: - check_version: + build: strategy: matrix: python-version: [ 3.8 ] - os: [ ubuntu-latest ] + os: [ ubuntu-latest, macOS-latest, windows-latest ] + fail-fast: false runs-on: ${{ matrix.os }} steps: - name: Checkout code - uses: actions/checkout@v2 + uses: actions/checkout@v4 - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v2 + uses: actions/setup-python@v4 with: python-version: ${{ matrix.python-version }} + cache: pip + cache-dependency-path: setup.py - - name: Check version - if: (github.event_name == 'pull_request' && github.base_ref == 'master') - run: | - python -m pip install --upgrade pip - - python -m pip install git+https://github.com/google-research/torchsde.git - master_info=$(pip list | grep torchsde) - master_version=$(echo ${master_info} | cut -d " " -f2) - python -m pip uninstall -y torchsde - - python setup.py install - pr_info=$(pip list | grep torchsde) - pr_version=$(echo ${pr_info} | cut -d " " -f2) + - name: Install + run: pip install pytest -e . --only-binary=numpy,scipy,matplotlib,torch + env: + PIP_EXTRA_INDEX_URL: https://download.pytorch.org/whl/cpu - python -c "import itertools as it - import sys - - _, master_version, pr_version = sys.argv - - master_version_ = [int(i) for i in master_version.split('.')] - pr_version_ = [int(i) for i in pr_version.split('.')] - - master_version__ = tuple(m for p, m in it.zip_longest(pr_version_, master_version_, fillvalue=0)) - pr_version__ = tuple(p for p, m in it.zip_longest(pr_version_, master_version_, fillvalue=0)) - sys.exit(pr_version__ < master_version__)" ${master_version} ${pr_version} + - name: Test with pytest + run: python -m pytest -v - build: - needs: [ check_version ] - strategy: - matrix: - python-version: [ 3.6, 3.8 ] - os: [ ubuntu-latest, macOS-latest, windows-latest ] - fail-fast: false - runs-on: ${{ matrix.os }} + lint: + runs-on: ubuntu-latest steps: - name: Checkout code - uses: actions/checkout@v2 + uses: actions/checkout@v4 - - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v2 + - uses: actions/setup-python@v4 with: - python-version: ${{ matrix.python-version }} - - - name: Install dependencies - run: | - python -m pip install --upgrade pip - python -m pip install flake8 pytest - - - name: Windows patch # Specifically for windows, since pip fails to fetch torch 1.6.0 as of Oct 2020. - if: runner.os == 'Windows' - run: python -m pip install torch==1.6.0+cpu torchvision==0.7.0+cpu -f https://download.pytorch.org/whl/torch_stable.html + python-version: "3.11" + cache: pip + cache-dependency-path: setup.py - name: Lint with flake8 run: | + python -m pip install flake8 python -m flake8 . - - - name: Test with pytest - run: | - python setup.py install - python -m pytest diff --git a/README.md b/README.md index 767a95f..84f9827 100644 --- a/README.md +++ b/README.md @@ -11,7 +11,7 @@ This library provides [stochastic differential equation (SDE)](https://en.wikipe pip install torchsde ``` -**Requirements:** Python >=3.6 and PyTorch >=1.6.0. +**Requirements:** Python >=3.8 and PyTorch >=1.6.0. ## Documentation Available [here](./DOCUMENTATION.md). diff --git a/setup.py b/setup.py index 23e873a..66f3d1f 100644 --- a/setup.py +++ b/setup.py @@ -40,15 +40,12 @@ url="https://github.com/google-research/torchsde", packages=setuptools.find_packages(exclude=['benchmarks', 'diagnostics', 'examples', 'tests']), install_requires=[ - "boltons>=20.2.1", - "numpy==1.19;python_version<'3.7'", - "numpy>=1.19;python_version>='3.7'", - "scipy==1.5;python_version<'3.7'", - "scipy>=1.5;python_version>='3.7'", + "numpy>=1.19", + "scipy>=1.5", "torch>=1.6.0", "trampoline>=0.1.2", ], - python_requires='~=3.6', + python_requires='>=3.8', classifiers=[ "Programming Language :: Python :: 3", "License :: OSI Approved :: Apache Software License", diff --git a/torchsde/_brownian/brownian_interval.py b/torchsde/_brownian/brownian_interval.py index afde9cb..79391d2 100644 --- a/torchsde/_brownian/brownian_interval.py +++ b/torchsde/_brownian/brownian_interval.py @@ -16,7 +16,6 @@ import trampoline import warnings -import boltons.cacheutils import numpy as np import torch @@ -112,6 +111,21 @@ def __getitem__(self, item): raise KeyError +class _LRUDict(dict): + def __init__(self, max_size): + super().__init__() + self._max_size = max_size + self._keys = [] + + def __setitem__(self, key, value): + if key in self: + self._keys.remove(key) + elif len(self) >= self._max_size: + del self[self._keys.pop(0)] + super().__setitem__(key, value) + self._keys.append(key) + + class _Interval: # Intervals correspond to some subinterval of the overall interval [t0, t1]. # They are arranged as a binary tree: each node corresponds to an interval. If a node has children, they are left @@ -505,7 +519,7 @@ def __init__(self, elif cache_size == 0: self._increment_and_space_time_levy_area_cache = _EmptyDict() else: - self._increment_and_space_time_levy_area_cache = boltons.cacheutils.LRU(max_size=cache_size) + self._increment_and_space_time_levy_area_cache = _LRUDict(max_size=cache_size) # We keep track of the most recently queried interval, and start searching for the next interval from that # element of the binary tree. This is because subsequent queries are likely to be near the most recent query.