Skip to content

Commit

Permalink
name change to bijax
Browse files Browse the repository at this point in the history
  • Loading branch information
patel-zeel committed Jul 13, 2022
1 parent 884ccbc commit 3387d99
Show file tree
Hide file tree
Showing 18 changed files with 151 additions and 215 deletions.
File renamed without changes.
File renamed without changes.
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,6 @@ __pycache__/
*.gif
*.txt
*.egg-info/
abi_jax/_version.py
bijax/_version.py
build
tmp.py
22 changes: 11 additions & 11 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,25 +1,25 @@
## ABI_JAX
## BIJAX

Approximate Bayesian Inference in JAX.
Bayesian Inference in JAX.

## Installation

```
pip install git+https://github.com/patel-zeel/abi_jax.git
pip install git+https://github.com/patel-zeel/bijax.git
```

## Methods implemented in abi_jax
## Methods implemented in BIJAX

* `from abi_jax.advi import ADVI` - [Automatic Differentiation Variational Inference](https://arxiv.org/abs/1603.00788)
* `from abi_jax.laplace import ADLaplace` - Automatic Differentiation Laplace approximation.
* `from abi_jax.mcmc import MCMC` - A helper class for external Markov Chain Monte Carlo (MCMC) sampling.
* `from bijax.advi import ADVI` - [Automatic Differentiation Variational Inference](https://arxiv.org/abs/1603.00788)
* [WIP]`from bijax.laplace import ADLaplace` - Automatic Differentiation Laplace approximation.
* `from bijax.mcmc import MCMC` - A helper class for external Markov Chain Monte Carlo (MCMC) sampling.

## How to use abi_jax?
## How to use BIJAX?

abi_jax is built without too many layers of abstractions or some new conventions. Thus, it is also useful for educational purposes. If you like to directly dive into the examples, please refer to the [examples](examples) directory.
BIJAX is built without layers of abstractions or proposing new conventions. Thus, it is also useful for educational purposes. If you like to directly dive into the examples, please refer to the [examples](examples) directory.


There are a few core components of abi_jax:
There are a few core components of bijax:

### Prior
`tensoflow_probability.substrates.jax` should be used to define the distributions for prior.
Expand Down Expand Up @@ -110,7 +110,7 @@ params = model.init(seed)
```

### Optimization
Models in abi_jax have `loss_fn` method which can be used to compute the loss. The loss can be optimized with any method that work with `JAX`. We also have a utility function `from abi_jax.utils import train` to train the model using `optax` optimizers.
Models in bijax have `loss_fn` method which can be used to compute the loss. The loss can be optimized with any method that work with `JAX`. We also have a utility function `from bijax.utils import train` to train the model using `optax` optimizers.

### Get the posterior distribution
Some of the models (`ADVI` and `ADLaplace`) support `.apply()` method to get the posterior distribution.
Expand Down
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
58 changes: 29 additions & 29 deletions examples/advi/bayesian_neural_network.ipynb

Large diffs are not rendered by default.

10 changes: 5 additions & 5 deletions examples/advi/coin_toss.ipynb

Large diffs are not rendered by default.

55 changes: 29 additions & 26 deletions examples/advi/gaussian_mixture_model.ipynb

Large diffs are not rendered by default.

123 changes: 28 additions & 95 deletions examples/advi/gp_classification.ipynb

Large diffs are not rendered by default.

86 changes: 43 additions & 43 deletions examples/advi/neural_linear_model.ipynb

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions examples/mcmc/eight_schools_unfinished.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
"import matplotlib.pyplot as plt\n",
"import seaborn as sns\n",
"\n",
"from ajax.mcmc import MCMC\n",
"from bijax.mcmc import MCMC\n",
"\n",
"from functools import partial\n",
"import regdata as rd\n",
Expand Down Expand Up @@ -396,7 +396,7 @@
],
"metadata": {
"kernelspec": {
"display_name": "Python [conda env:ajax]",
"display_name": "Python [conda env:bijax]",
"language": "python",
"name": "conda-env-ajax-py"
},
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ requires = [
]

[tool.setuptools_scm]
write_to = "abi_jax/_version.py"
write_to = "bijax/_version.py"

[tool.black]
line-length = 120
Expand Down
4 changes: 2 additions & 2 deletions setup.cfg
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
[metadata]
name = abi_jax
name = bijax
author = Zeel B Patel
author-email = [email protected]
description = Approximate Bayesian Inference in JAX
url = https://github.com/patel-zeel/abi_jax
url = https://github.com/patel-zeel/bijax
license = MIT
long_description_content_type = text/markdown
long_description = file: README.md

0 comments on commit 3387d99

Please sign in to comment.