bambi dev version with bayeux give wrong posterior dims for hierarchical model (mixed and dropped dimensions) #800

Closed
danieltomasz opened this issue Apr 8, 2024 · 12 comments · Fixed by #803

@danieltomasz

Bambi 0.13 (in Python 3.11) gives the expected results, but the git version does not (ArviZ 0.18 in both environments).

import requests
import pandas as pd
from io import StringIO
import bambi as bmb

url = "https://raw.githubusercontent.com/crnolan/pyrba/main/data.txt"  # replace with your url
response = requests.get(url)
data = response.text

# Convert the string to a file-like object
data_io = StringIO(data)

# Read the data into a DataFrame
df = pd.read_table(data_io, delimiter=r"\s+")

print(df.head())
print(df.nunique())
model = bmb.Model("y ~  (1|subject) + (1|ROI)", df)
results = model.fit(
    tune=4000,
    draws=1000,
    chains=8,
    inference_method="numpyro_nuts",
    max_tree_depth=3,
)

The git version does not actually return the hierarchical model. The data variables

1|subject ~ (chain, draw, subject__factor_dim)
1|ROI ~ (chain, draw, ROI__factor_dim)

are dropped, and the subject and ROI factor dimensions are mixed up (compare with the level counts from print(df.nunique())):

Dimensions: chain: 8, draw: 1000, subject__factor_dim: 21, ROI__factor_dim: 124
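One way to catch this kind of mix-up is to compare each `<factor>__factor_dim` size in the posterior against the number of levels of that factor in the data. The helper below is a hypothetical sketch, not a Bambi or ArviZ function, and it assumes the dimension naming shown in the output above. The level counts in the example are hypothetical placeholders standing in for `df.nunique()`.

```python
def factor_dim_mismatches(posterior_sizes, level_counts):
    """Return {dim_name: (size_in_posterior, levels_in_data)} for every
    grouping factor whose posterior dimension size disagrees with the data.

    Assumes group-level dims are named "<factor>__factor_dim".
    """
    mismatches = {}
    for factor, n_levels in level_counts.items():
        dim = f"{factor}__factor_dim"
        size = posterior_sizes.get(dim)  # None if the dim is missing entirely
        if size != n_levels:
            mismatches[dim] = (size, n_levels)
    return mismatches


# Sizes as printed for the dev-version posterior.
posterior_sizes = {"chain": 8, "draw": 1000, "subject__factor_dim": 21, "ROI__factor_dim": 124}
# Hypothetical level counts, standing in for df.nunique().
level_counts = {"subject": 124, "ROI": 21}

print(factor_dim_mismatches(posterior_sizes, level_counts))
# {'subject__factor_dim': (21, 124), 'ROI__factor_dim': (124, 21)}
```

With real results, `results.posterior.sizes` and `df.nunique()[["subject", "ROI"]].to_dict()` could be passed in instead of the hand-written dictionaries.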

import arviz as az
az.plot_trace(results, figsize=(20, 35))
print(az.summary(results))
                 mean     sd  hdi_3%  hdi_97%  mcse_mean  mcse_sd  ess_bulk  \
1|ROI_sigma      0.155  0.019   0.122    0.192      0.004    0.003      29.0   
1|subject_sigma  0.071  0.018   0.040    0.094      0.006    0.004      15.0   
Intercept        0.151  0.050   0.053    0.236      0.017    0.012       9.0   
y_sigma          0.156  0.005   0.149    0.166      0.001    0.001      18.0   

                 ess_tail  r_hat  
1|ROI_sigma         208.0   1.20  
1|subject_sigma      10.0   1.51  
Intercept            21.0   2.65  
y_sigma              10.0   1.38  

While 0.13 gives:


                    mean     sd  hdi_3%  hdi_97%  mcse_mean  mcse_sd  \
1|ROI[ACC]        -0.107  0.038  -0.175   -0.035      0.011    0.008   
1|ROI[LAmy/Hippo] -0.061  0.038  -0.127    0.013      0.011    0.008   
1|ROI[LCing]      -0.204  0.038  -0.275   -0.136      0.011    0.008   
1|ROI[LIFG]       -0.119  0.038  -0.185   -0.046      0.011    0.008   
1|ROI[LIPL]        0.078  0.038   0.012    0.149      0.011    0.008   
...                  ...    ...     ...      ...        ...      ...   
1|subject[HMN199]  0.033  0.031  -0.026    0.091      0.002    0.002   
1|subject[HMN201]  0.126  0.031   0.073    0.187      0.002    0.002   
1|subject_sigma    0.078  0.006   0.066    0.090      0.001    0.001   
Intercept          0.159  0.037   0.085    0.222      0.012    0.008   
y_sigma            0.154  0.002   0.150    0.158      0.000    0.000   

                   ess_bulk  ess_tail  r_hat  
1|ROI[ACC]             12.0      25.0   1.80  
1|ROI[LAmy/Hippo]      12.0      27.0   1.81  
1|ROI[LCing]           12.0      23.0   1.83  
1|ROI[LIFG]            12.0      24.0   1.79  
1|ROI[LIPL]            12.0      26.0   1.84  
...                     ...       ...    ...  
1|subject[HMN199]     164.0     398.0   1.06  
1|subject[HMN201]     195.0     340.0   1.05  
1|subject_sigma        42.0     122.0   1.15  
Intercept              11.0      22.0   2.10  
y_sigma               111.0     305.0   1.07  

[149 rows x 9 columns]

Another example does not even return the posterior:

import requests
import pandas as pd
from io import StringIO
import bambi as bmb

url = 'https://raw.githubusercontent.com/crnolan/pyrba/main/data.txt'  # replace with your url
response = requests.get(url)
data = response.text

# Convert the string to a file-like object
data_io = StringIO(data)

# Read the data into a DataFrame
df = pd.read_table(data_io, delimiter=r"\s+")

print(df.head())

model = bmb.Model("y ~ x + (1|subject) + (x|ROI)", df)
results = model.fit(
    tune=4000,
    draws=1000,
    chains=8,
    inference_method="numpyro_nuts",
    nuts_kwargs=dict(max_tree_depth=100),
)

This gives the following error:

File ~/.pyenv/versions/pyrba-3.12/lib/python3.12/site-packages/xarray/namedarray/utils.py
    177         f"{dims_supplied} must be a permuted list of {dims_all}, unless `...` is included"
    178     )
    179 yield from existing_dims

ValueError: ('chain', 'draw', 'subject__factor_dim', 'ROI__factor_dim') must be a permuted list of FrozenMappingWarningOnValuesAccess({'chain': 8, 'draw': 1000, 'subject__factor_dim': 21, 'ROI__factor_dim': 124, 'x|ROI_offset_dim_0': 21}), unless `...` is included

The problem was discussed earlier in #799.

@danieltomasz changed the title "Git version with bayeux give wrong posterior dims for hierarchical model" to "Git version with bayeux give wrong posterior dims for hierarchical model (mixed and dropped dimensions)" Apr 8, 2024
@danieltomasz changed the title "Git version with bayeux give wrong posterior dims for hierarchical model (mixed and dropped dimensions)" to "bambi dev version with bayeux give wrong posterior dims for hierarchical model (mixed and dropped dimensions)" Apr 8, 2024
@tomicapretto
Collaborator

This is happening because bayeux does not include what PyMC calls deterministic variables (i.e. parameters that are determined by values of other parameters). PyMC now has pm.compute_deterministics() (pymc-devs/pymc#7238) and it may be of help in these cases. This is something we need to see how to handle internally.

For example, the first model you shared makes use of deterministics:
[image]
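To make the idea concrete: assuming Bambi's default non-centered parameterization, where the group-specific effect `1|subject` is the deterministic product `1|subject_offset * 1|subject_sigma`, a sampler that returns only the free parameters leaves that product to be recomputed for every draw. The NumPy sketch below does this by hand on synthetic draws (not the posterior from this issue); pm.compute_deterministics() does roughly the equivalent, evaluating the deterministics from the model graph for each posterior sample.

```python
import numpy as np

rng = np.random.default_rng(0)
n_chains, n_draws, n_levels = 4, 250, 5

# Free parameters, the only variables a bayeux-style sampler would return
# (synthetic draws for illustration).
offset = rng.normal(size=(n_chains, n_draws, n_levels))  # "1|subject_offset"
sigma = np.abs(rng.normal(size=(n_chains, n_draws)))     # "1|subject_sigma"

# Deterministic recomputed draw by draw; sigma is broadcast over the level
# axis, giving shape (chain, draw, level).
effect = offset * sigma[..., np.newaxis]                  # "1|subject"

print(effect.shape)  # (4, 250, 5)
```

If this step is skipped, the posterior only contains the offset and sigma variables, which matches the dropped `1|subject` and `1|ROI` variables reported above.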

@ColCarroll
Collaborator

I guess implementing the suggestion in this comment on jax-ml/bayeux#21 would fix this?

I'm hoping to have some bandwidth this week. In case someone else wants to make a PR, though, I'll add details to the linked issues (the details will be to copy what PyMC does, and to open issues with PyMC to make this a public API so it is somewhat stable).

@tomicapretto
Collaborator

Right now, I'm testing an implementation with pm.compute_deterministics(). It's very simple: if the inference data is obtained with bayeux, we use pm.compute_deterministics() and PyMC handles the logic for us. I'll keep you updated.

@tomicapretto
Collaborator

@danieltomasz can you install from the branch in this PR (#803)?

Your models should run.

@danieltomasz
Author

danieltomasz commented Apr 9, 2024

@tomicapretto when I try to run the code:

model = bmb.Model("y ~  (1|subject) + (1|ROI)", df)
results = model.fit(
    tune=4000,
    draws=1000,
    chains=8,
    inference_method="numpyro_nuts",
    nuts_kwargs=dict(max_tree_depth=3),
)

I got

NotImplementedError: 'numpyro_nuts' method has not been implemented

My test env has numpyro 0.14.0.

I installed it via conda (I had problems with pytensor installed via pip on M1):

channels:
  - conda-forge
dependencies:
  - conda-forge::python=3.12.2
  - conda-forge::pytensor=2.20
  - conda-forge::pandas
  - conda-forge::ipykernel
  - conda-forge::pip
  - conda-forge::ipywidgets
  - pip:
    - git+https://github.com/tomicapretto/bambi.git@support_pymc_5_13

(Edit: the first time I tried to install bambi from this branch I got the wrong version; the second time it was 0.13.)

@danieltomasz
Author

danieltomasz commented Apr 9, 2024

Sorry, for some reason conda ignored the pip install. I will install the version from the branch directly in Jupyter and check.

@danieltomasz
Copy link
Author

danieltomasz commented Apr 9, 2024

Should the result of
!pip install git+https://github.com/tomicapretto/bambi.git@support_pymc_5_13
be bambi version v0.2.1.dev340+gb431d81?

It was really unexpected to see this version, but installing from the latest commit also yields the same version.

I am still getting NotImplementedError: 'numpyro_nuts' method has not been implemented, though.

@danieltomasz
Author

danieltomasz commented Apr 9, 2024

The reason was that bayeux-ml wasn't installed in my test env; I spent 30 minutes trying to debug it :P The error should be more informative, especially since this is a breaking change from the previous behaviour. Alternatively, bayeux-ml could be added as a main dependency, but it is not yet on conda-forge and is currently just optional, I think.

@danieltomasz
Copy link
Author

danieltomasz commented Apr 9, 2024

Also, bambi 0.13 gives me these warnings when I set draws=1000:

The rhat statistic is larger than 1.01 for some parameters. This indicates problems during sampling. See https://arxiv.org/abs/1903.08008 for details
The effective sample size per chain is smaller than 100 for some parameters. A higher number is needed for reliable rhat and ess computation. See https://arxiv.org/abs/1903.08008 for details

The dev version is silent with the same number of draws. Is the new implementation just better, or is there no check?

@tomicapretto
Collaborator

@danieltomasz bayeux-ml is not installed by default. It works if you do pip install bambi[jax], which installs all the dependencies required to work with JAX-based samplers.

The version name is automatically generated. This is done on purpose. If we kept 0.13dev for a while, we would end up with multiple 0.13dev versions containing different code, which is not good. Also, this automatic versioning system ensures that, if you install from the main branch, the library is re-installed whenever there's a new commit. If we used 0.13dev, you would have to force the re-installation (otherwise pip doesn't re-install, as it sees you already have that version installed).
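Assuming a setuptools_scm-style scheme (an assumption; the comment only says the name is generated automatically), a dev version such as v0.2.1.dev340+gb431d81 reads as `<next version>.dev<commits since the last tag>+g<commit hash>`. The regex below is a sketch for splitting it into those fields. Because the commit count and hash change with every commit, each commit yields a distinct version string, which is what lets pip notice a new install.

```python
import re

# Assumed layout: [v]<base>.dev<distance>+g<hash>
DEV_VERSION = re.compile(
    r"^v?(?P<base>\d+(?:\.\d+)*)\.dev(?P<distance>\d+)\+g(?P<node>[0-9a-f]+)$"
)


def parse_dev_version(version):
    """Split a dev version string into (base, commits since tag, commit hash)."""
    match = DEV_VERSION.match(version)
    if match is None:
        raise ValueError(f"not a dev version string: {version!r}")
    return match["base"], int(match["distance"]), match["node"]


print(parse_dev_version("v0.2.1.dev340+gb431d81"))  # ('0.2.1', 340, 'b431d81')
```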

As for the r-hat stats, are you using the same random seed? It may be just bad luck. We have not changed the implementation.
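To check the chains yourself regardless of whether a warning is printed, the two thresholds quoted in the warning (r_hat above 1.01, effective sample size per chain below 100) can be applied to the summary table. The function below is a hypothetical sketch, not PyMC's or Bambi's own check, whose exact rule may differ in detail; the rows are copied from the dev-version summary in this issue (8 chains, 1000 draws).

```python
RHAT_MAX = 1.01           # "rhat statistic is larger than 1.01"
MIN_ESS_PER_CHAIN = 100   # "effective sample size per chain is smaller than 100"


def convergence_flags(rows, n_chains):
    """rows maps parameter name -> (r_hat, ess_bulk).

    Returns (params with high r_hat, params with low bulk ESS per chain).
    """
    high_rhat = [name for name, (r_hat, _) in rows.items() if r_hat > RHAT_MAX]
    low_ess = [name for name, (_, ess) in rows.items() if ess / n_chains < MIN_ESS_PER_CHAIN]
    return high_rhat, low_ess


# Rows from the dev-version summary at the top of this issue.
dev_rows = {
    "1|ROI_sigma": (1.20, 29.0),
    "1|subject_sigma": (1.51, 15.0),
    "Intercept": (2.65, 9.0),
    "y_sigma": (1.38, 18.0),
}
high_rhat, low_ess = convergence_flags(dev_rows, n_chains=8)
print(high_rhat)
print(low_ess)
```

By these thresholds every dev-version row is flagged, so the silence there does not by itself indicate better convergence. With ArviZ, the rows could be built from `summary = az.summary(results)` as `dict(zip(summary.index, zip(summary["r_hat"], summary["ess_bulk"])))`.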

@danieltomasz
Author

danieltomasz commented Apr 10, 2024

Hi @tomicapretto, thanks for the reply! Yes, I kind of figured out that this version string is some special way of marking versions; that's why I deleted my previous comment before reading your reply.
Regarding bayeux-ml, I have now figured out twice that it is an optional library; the second time took me a bit longer.

My remark was really about a better error message: the code worked with previous versions of Bambi (including 0.13) without bayeux-ml in a virtual test environment, so for someone updating from an older version it might not be obvious that bayeux-ml needs to be installed. I run a Mac with M1, where only the pytensor version from conda-forge works without errors, and with conda I cannot install bambi[jax], so I need to add the optional dependencies manually.

With a more specific error message telling me to install bayeux-ml when I try to use numpyro_nuts without it, I would have found the cause faster (or remembered what I had learned before setting up the environment to test bambi).

@tomicapretto
Collaborator

@danieltomasz thanks for the suggestion, I really appreciate it. We're still preparing ourselves for a 0.14.0 release and I think before that we need to make sure users receive an informative message when they try to use a JAX-based sampler.
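A check along the lines suggested could look like the sketch below. It is hypothetical (the function name and the sampler-to-module mapping are assumptions, not Bambi's code): before dispatching to a JAX-based sampler, it looks for the `bayeux` module with `importlib.util.find_spec` and, if it is missing, raises an ImportError that names the bayeux-ml package and the `bambi[jax]` extra instead of a generic NotImplementedError.

```python
import importlib.util

# Hypothetical mapping: sampler name -> module that must be importable.
# bayeux-ml is installed from PyPI under that name but imported as "bayeux".
JAX_SAMPLER_MODULES = {
    "numpyro_nuts": "bayeux",
    "blackjax_nuts": "bayeux",
}


def require_sampler_dependencies(inference_method):
    """Raise an informative ImportError if a JAX-based sampler's backend is missing."""
    module = JAX_SAMPLER_MODULES.get(inference_method)
    if module is None:
        return  # not a JAX-based sampler; nothing to check here
    if importlib.util.find_spec(module) is None:
        raise ImportError(
            f"inference_method={inference_method!r} needs the optional package "
            "'bayeux-ml', which is not installed. "
            "Install it with `pip install bayeux-ml` or `pip install bambi[jax]`."
        )
```

`find_spec` only locates the module without importing it, so the check stays cheap even when the JAX stack is installed.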


3 participants