Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update to PyMC5 #108

Merged
merged 2 commits into from
Dec 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ jobs:
matrix:
python-version: [
"3.11",
"3.12"
]
steps:
- uses: actions/checkout@v4
Expand All @@ -29,4 +30,4 @@ jobs:
run: ruff check

- name: Tests
run: THEANO_FLAGS="blas.ldflags=" pytest tests -r a
run: pytest tests -r a
21 changes: 2 additions & 19 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ Clone the repository and checkout the current development branch.
Create a virtual environment.

```
python3.11 -m venv env
python3 -m venv env
```

Activate the environment (adjust if using a non-default shell)
Expand All @@ -30,31 +30,14 @@ Install ATTRICI as a local development version with dev dependencies included
pip install -e .[dev]
```

At the current development stage ATTRICI uses Theano and you may need to create a Theano config file at `~/.theanorc` with settings like the following:

```
[global]
device = cpu
floatX = float64
cxx = g++
mode = FAST_RUN
openmp = True

[gcc]
cxxflags = -O3 -march=native -funroll-loops

[blas]
ldflags =
```


## Usage

See [USAGE.md](USAGE.md) for examples.

## Credits

We rely on the [pymc3](https://github.com/pymc-devs/pymc3) package for probabilistic programming (Salvatier et al. 2016).
We rely on the [PyMC](https://www.pymc.io/) package for probabilistic programming (Salvatier et al. 2016).

An early version of the code on Bayesian estimation of parameters in timeseries with periodicity in PyMC3 was inspired by [Ritchie Vink's](https://www.ritchievink.com) [post](https://www.ritchievink.com/blog/2018/10/09/build-facebooks-prophet-in-pymc3-bayesian-time-series-analyis-with-generalized-additive-models/) on Bayesian timeseries analysis with additive models.

Expand Down
10 changes: 0 additions & 10 deletions USAGE.md
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,6 @@ attrici detrend --config runconfig.toml

As a computationally expensive operation, the `detrend` sub-command is designed to be run in parallel (for each geographical cell).
To make use of this parallelization, specify the arguments `--task-id ID` and `--task-count COUNT` and start several instances with `N` going from `0` to `N-1`. `N` does not have to equal the number of cells - these will be distributed to instances accordingly.
As, at this stage, the Theano library is used that compiles the estimatin model into a cache, make sure that the cache directory is different for each instance (using the `THEANO_FLAG` option `base_compiledir`; unfortunately, this implies that some of the joint compilation cannot be cached).

For the SLURM scheduler, which is widely used on HPC platforms, you can use an `sbatch` run script such as the following (here `N=4`):

Expand All @@ -134,11 +133,6 @@ For the SLURM scheduler, which is widely used on HPC platforms, you can use an `

export OMP_NUM_THREADS=$SLURM_CPUS_PER_TASK

TMP_COMPILEDIR=$(mktemp -d)
export THEANO_FLAGS="blas.ldflags=,base_compiledir=$TMP_COMPILEDIR"

trap 'rm -r $TMP_COMPILEDIR' EXIT

srun attrici \
detrend \
--gmt-file ./tests/data/20CRv3-ERA5_germany_ssa_gmt.nc \
Expand Down Expand Up @@ -175,10 +169,6 @@ If you prefer SLURM tasks rather than job arrays, an example scheduling script w
export OMP_NUM_THREADS=$SLURM_CPUS_PER_TASK

srun bash <<'EOF'
TMP_COMPILEDIR=$(mktemp -d)
export THEANO_FLAGS="blas.ldflags=,base_compiledir=$TMP_COMPILEDIR"

trap 'rm -r $TMP_COMPILEDIR' EXIT

exec attrici \
detrend \
Expand Down
7 changes: 7 additions & 0 deletions attrici/commands/detrend.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,13 @@ def add_parser(subparsers):
default=Config.__dataclass_fields__["seed"].default,
help="Seed for deterministic randomisation",
)
group.add_argument(
"--solver",
type=str,
choices=["pymc5"],
default="pymc5",
help="Solver library for statistical modelling",
)
group.add_argument(
"--start-date",
type=iso_date,
Expand Down
23 changes: 13 additions & 10 deletions attrici/detrend.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
from loguru import logger

from attrici import variables
from attrici.estimation.model_pymc3 import ModelPymc3
from attrici.util import get_data_provenance_metadata, timeit

MODEL_FOR_VAR = {
Expand Down Expand Up @@ -54,6 +53,8 @@ class Config:
"""List of variables to include in the output """
seed: int = 0
"""Seed for deterministic randomisation"""
solver: str = "pymc5"
"""Solver library for statistical modelling"""
start_date: date | None = None
"""Optional start date YYYY-MM-DD"""
stop_date: date | None = None
Expand Down Expand Up @@ -199,14 +200,7 @@ def save_trace(trace, filename):
) # TODO use a different format than pickle


def detrend_cell(
config,
data,
gmt_scaled,
subset_times,
lat,
lon,
):
def detrend_cell(config, data, gmt_scaled, subset_times, lat, lon, model_class):
output_filename = (
Path(config.output_dir)
/ "timeseries"
Expand All @@ -229,8 +223,9 @@ def detrend_cell(
data[np.isinf(data)] = np.nan

variable = MODEL_FOR_VAR[config.variable](data)

statistical_model = variable.create_model(
ModelPymc3, gmt_scaled.sel(time=subset_times), config.modes
model_class, gmt_scaled.sel(time=subset_times), config.modes
)

trace = None
Expand Down Expand Up @@ -384,6 +379,13 @@ def detrend(config: Config):
(obs_data.time >= startdate) & (obs_data.time <= stopdate)
]

if config.solver == "pymc5":
from attrici.estimation.model_pymc5 import ModelPymc5

model_class = ModelPymc5
else:
raise ValueError(f"Unknown solver {config.solver}")

for lat, lat_index, lon, lon_index in indices:
logger.info(
"This is task {} working on lat,lon {},{}", config.task_id, lat, lon
Expand All @@ -396,4 +398,5 @@ def detrend(config: Config):
subset_times,
lat,
lon,
model_class,
)
8 changes: 6 additions & 2 deletions attrici/estimation/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,18 @@
from typing import Callable


class Parameter:
pass


class AttriciGLM:
@dataclass
class PredictorDependentParam:
class PredictorDependentParam(Parameter):
link: Callable
modes: int

@dataclass
class PredictorIndependentParam:
class PredictorIndependentParam(Parameter):
link: Callable
modes: int

Expand Down
Loading
Loading