Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

(WIP) Fix CI #1

Closed
wants to merge 8 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 21 additions & 54 deletions .github/workflows/run_tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -9,78 +9,45 @@ on:
- cron: "0 2 * * 6"

jobs:
check_version:
build:
strategy:
matrix:
python-version: [ 3.8 ]
os: [ ubuntu-latest ]
os: [ ubuntu-latest, macOS-latest, windows-latest ]
fail-fast: false
runs-on: ${{ matrix.os }}
steps:
- name: Checkout code
uses: actions/checkout@v2
uses: actions/checkout@v4

- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}
cache: pip
cache-dependency-path: setup.py

- name: Check version
if: (github.event_name == 'pull_request' && github.base_ref == 'master')
run: |
python -m pip install --upgrade pip

python -m pip install git+https://github.com/google-research/torchsde.git
master_info=$(pip list | grep torchsde)
master_version=$(echo ${master_info} | cut -d " " -f2)
python -m pip uninstall -y torchsde

python setup.py install
pr_info=$(pip list | grep torchsde)
pr_version=$(echo ${pr_info} | cut -d " " -f2)
- name: Install
run: pip install pytest -e . --only-binary=numpy,scipy,matplotlib,torch
env:
PIP_EXTRA_INDEX_URL: https://download.pytorch.org/whl/cpu

python -c "import itertools as it
import sys

_, master_version, pr_version = sys.argv

master_version_ = [int(i) for i in master_version.split('.')]
pr_version_ = [int(i) for i in pr_version.split('.')]

master_version__ = tuple(m for p, m in it.zip_longest(pr_version_, master_version_, fillvalue=0))
pr_version__ = tuple(p for p, m in it.zip_longest(pr_version_, master_version_, fillvalue=0))
sys.exit(pr_version__ < master_version__)" ${master_version} ${pr_version}
- name: Test with pytest
run: python -m pytest -v

build:
needs: [ check_version ]
strategy:
matrix:
python-version: [ 3.6, 3.8 ]
os: [ ubuntu-latest, macOS-latest, windows-latest ]
fail-fast: false
runs-on: ${{ matrix.os }}
lint:
runs-on: ubuntu-latest
steps:
- name: Checkout code
uses: actions/checkout@v2
uses: actions/checkout@v4

- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
- uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}

- name: Install dependencies
run: |
python -m pip install --upgrade pip
python -m pip install flake8 pytest

- name: Windows patch # Specifically for windows, since pip fails to fetch torch 1.6.0 as of Oct 2020.
if: runner.os == 'Windows'
run: python -m pip install torch==1.6.0+cpu torchvision==0.7.0+cpu -f https://download.pytorch.org/whl/torch_stable.html
python-version: "3.11"
cache: pip
cache-dependency-path: setup.py

- name: Lint with flake8
run: |
python -m pip install flake8
python -m flake8 .

- name: Test with pytest
run: |
python setup.py install
python -m pytest
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ This library provides [stochastic differential equation (SDE)](https://en.wikipe
pip install torchsde
```

**Requirements:** Python >=3.6 and PyTorch >=1.6.0.
**Requirements:** Python >=3.8 and PyTorch >=1.6.0.

## Documentation
Available [here](./DOCUMENTATION.md).
Expand Down
9 changes: 3 additions & 6 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,15 +40,12 @@
url="https://github.com/google-research/torchsde",
packages=setuptools.find_packages(exclude=['benchmarks', 'diagnostics', 'examples', 'tests']),
install_requires=[
"boltons>=20.2.1",
"numpy==1.19;python_version<'3.7'",
"numpy>=1.19;python_version>='3.7'",
"scipy==1.5;python_version<'3.7'",
"scipy>=1.5;python_version>='3.7'",
"numpy>=1.19",
"scipy>=1.5",
"torch>=1.6.0",
"trampoline>=0.1.2",
],
python_requires='~=3.6',
python_requires='>=3.8',
classifiers=[
"Programming Language :: Python :: 3",
"License :: OSI Approved :: Apache Software License",
Expand Down
18 changes: 16 additions & 2 deletions torchsde/_brownian/brownian_interval.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
import trampoline
import warnings

import boltons.cacheutils
import numpy as np
import torch

Expand Down Expand Up @@ -112,6 +111,21 @@ def __getitem__(self, item):
raise KeyError


class _LRUDict(dict):
def __init__(self, max_size):
super().__init__()
self._max_size = max_size
self._keys = []

def __setitem__(self, key, value):
if key in self:
self._keys.remove(key)
elif len(self) >= self._max_size:
del self[self._keys.pop(0)]
super().__setitem__(key, value)
self._keys.append(key)


class _Interval:
# Intervals correspond to some subinterval of the overall interval [t0, t1].
# They are arranged as a binary tree: each node corresponds to an interval. If a node has children, they are left
Expand Down Expand Up @@ -505,7 +519,7 @@ def __init__(self,
elif cache_size == 0:
self._increment_and_space_time_levy_area_cache = _EmptyDict()
else:
self._increment_and_space_time_levy_area_cache = boltons.cacheutils.LRU(max_size=cache_size)
self._increment_and_space_time_levy_area_cache = _LRUDict(max_size=cache_size)

# We keep track of the most recently queried interval, and start searching for the next interval from that
# element of the binary tree. This is because subsequent queries are likely to be near the most recent query.
Expand Down