Skip to content

Commit

Permalink
Fix Mypy GitHub Action (#70)
Browse files Browse the repository at this point in the history
* Simplify the mypy check and see if it fails

* Fix mypy issues

* Install Jax for mypy check
  • Loading branch information
tushuhei authored Jul 12, 2022
1 parent 893c61b commit 16c0be5
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 12 deletions.
9 changes: 3 additions & 6 deletions .github/workflows/style-check.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,19 +12,16 @@ jobs:
run: |
pip install --upgrade pip
pip install ".[dev]"
pip install jax[cpu]
- name: Run isort
run: |
isort --diff --check .
- name: Run yapf
run: |
yapf --diff --recursive budoux tests scripts
- name: Run mypy
if: ${{ always() }}
uses: sasanquaneuf/mypy-github-action@a0c442aa252655d7736ce6696e06227ccdd62870
with:
checkName: python-style-check
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
run: |
mypy budoux tests scripts
- name: Run flake8
if: ${{ always() }}
uses: suo/flake8-github-action@3e87882219642e01aa8a6bbd03b4b0adb8542c2a
Expand Down
14 changes: 8 additions & 6 deletions scripts/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from jax import device_put, jit
jax_installed = True
except ModuleNotFoundError:
import numpy as jnp
import numpy as jnp # type: ignore

EPS = np.finfo(float).eps # type: np.floating[typing.Any]

Expand Down Expand Up @@ -90,10 +90,12 @@ def pred(phis: typing.Dict[int, float],
alphas: npt.NDArray[np.float64]
y: npt.NDArray[np.int64]

alphas = jnp.array(list(phis.values()))
y = 2 * (X[:, list(phis.keys())]
== True) - 1 # noqa (cannot replace `==` with `is`)
return y.dot(alphas) > 0
alphas = jnp.array(list(phis.values())) # type: ignore
y = 2 * (
X[:, list(phis.keys())] == True # noqa (cannot replace `==` with `is`)
) - 1
result: npt.NDArray[np.bool_] = y.dot(alphas) > 0
return result


def split_dataset(
Expand Down Expand Up @@ -174,7 +176,7 @@ def fit(X_train: npt.NDArray[np.bool_],
X_test = device_put(X_test)
Y_test = device_put(Y_test)
N_train, M_train = X_train.shape
w = jnp.ones(N_train) / N_train
w = jnp.ones(N_train) / N_train # type: ignore
YX_train = Y_train[:, None] ^ X_train
for t in range(iters):
print('=== %s ===' % (t))
Expand Down

0 comments on commit 16c0be5

Please sign in to comment.