Skip to content

Commit

Permalink
change: all plot return a fig, ax objects if the user sets show_plot=…
Browse files Browse the repository at this point in the history
…False
  • Loading branch information
givasile committed Feb 17, 2025
1 parent 63886fa commit 80d2d38
Show file tree
Hide file tree
Showing 7 changed files with 87 additions and 26 deletions.
10 changes: 1 addition & 9 deletions .github/workflows/publish_to_pypi.yml
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,4 @@ jobs:
run: python -m build

- name: Publish to PyPI
uses: pypa/gh-action-pypi-publish@release/v1

- name: Create GitHub Release
if: startsWith(github.ref, 'refs/tags/v') && contains(github.ref, '.')
uses: softprops/action-gh-release@v1
with:
files: dist/*
env:
GITHUB_TOKEN: ${{ secrets.EFFECTOR_GITHUB_API_KEY }}
uses: pypa/gh-action-pypi-publish@release/v1
13 changes: 13 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,18 @@
# Changelog

## [Unreleased]

### Changed

- all plots return a `fig, ax` tuple, if the user wants to modify the plot further.
- default plot titles now display full method name, e.g., `Accumulated Local Effects` instead of `ALE`.
- changed README.md to reflect the new changes.
- license

### Added

- codecov badge to README.md

## [0.1.1] - 2025-02-17

### Changed
Expand Down
22 changes: 15 additions & 7 deletions docs/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,32 +2,38 @@

---

## Copy changelog
## TODO before publishing

### Copy changelog

```bash
cp ../CHANGELOG.md docs/changelog.md
```

---

### Update the notebooks

If you want to clean everything first:
#### (Optional) Remove the old notebooks

```bash
rm -rf docs/notebooks
mkdir docs/notebooks
```

To update the docs with the latest tutorials:
#### Convert the notebooks to markdown

```bash
jupyter nbconvert --to markdown ./../notebooks/real-examples/* --output-dir docs/notebooks/real-examples/
jupyter nbconvert --to markdown ./../notebooks/synthetic-examples/* --output-dir docs/notebooks/synthetic-examples/
jupyter nbconvert --to markdown ./../notebooks/quickstart/* --output-dir docs/notebooks/quickstart/
jupyter nbconvert --to markdown ./../notebooks/guides/* --output-dir docs/notebooks/guides/
```
---

Then copy some on the static folder:
### Update the images on the `./static` folder

First create dir, if it does not exist:
#### Create the folders if they don't exist

```bash
mkdir docs/static/quickstart/
Expand All @@ -37,15 +43,17 @@ mkdir docs/static/quickstart/flexible_api_files/
mkdir docs/static/real-examples/01_bike_sharing_dataset_files/
```

Then copy the files:
#### Copy the images

```bash
cp docs/notebooks/quickstart/simple_api_files/* docs/static/quickstart/simple_api_files/
cp docs/notebooks/quickstart/flexible_api_files/* docs/static/quickstart/flexible_api_files/
cp docs/notebooks/real-examples/01_bike_sharing_dataset_files/* docs/static/real-examples/01_bike_sharing_dataset_files/
```

To launch the documentation locally:
---

## Run the documentation locally

```zsh
mkdocs serve
Expand Down
13 changes: 12 additions & 1 deletion effector/global_effect_ale.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,7 @@ def plot(
y_limits: Optional[List] = None,
dy_limits: Optional[List] = None,
show_only_aggregated: bool = False,
show_plot: bool = True,
):
"""
Plot the (RH)ALE feature effect of feature `feature`.
Expand Down Expand Up @@ -206,6 +207,9 @@ def plot(
- If set to None, the limits of the dy-axis are set automatically
- If set to a tuple, the limits are manually set
show_only_aggregated: if True, only the main ale plot will be shown
show_plot: if True, the plot will be shown
"""
heterogeneity = helpers.prep_confidence_interval(heterogeneity)
centering = helpers.prep_centering(centering)
Expand All @@ -222,7 +226,7 @@ def plot(
else:
avg_output = None

vis.ale_plot(
ret = vis.ale_plot(
self.feature_effect["feature_" + str(feature)],
self.eval,
feature,
Expand All @@ -237,8 +241,15 @@ def plot(
y_limits=y_limits,
dy_limits=dy_limits,
show_only_aggregated=show_only_aggregated,
show_plot=show_plot,
)

if not show_plot:
fig, ax = ret
return fig, ax




class ALE(ALEBase):
def __init__(
Expand Down
22 changes: 19 additions & 3 deletions effector/global_effect_pdp.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,7 @@ def _plot(
show_avg_output: bool = False,
y_limits: Optional[List] = None,
use_vectorized: bool = True,
show_plot: bool = True,
):
heterogeneity = helpers.prep_confidence_interval(heterogeneity)
centering = helpers.prep_centering(centering)
Expand All @@ -212,7 +213,7 @@ def _plot(
avg_output = None

title = "PDP" if self.method_name == "pdp" else "d-PDP"
vis.plot_pdp_ice(
ret = vis.plot_pdp_ice(
x,
feature,
yy=yy,
Expand All @@ -227,7 +228,11 @@ def _plot(
target_name=self.target_name,
nof_ice=nof_ice,
y_limits=y_limits,
show_plot=show_plot,
)
if not show_plot:
fig, ax = ret
return fig, ax

class PDP(PDPBase):
def __init__(
Expand Down Expand Up @@ -330,6 +335,7 @@ def plot(
show_avg_output: bool = False,
y_limits: Optional[List] = None,
use_vectorized: bool = True,
show_plot: bool = True,
):
"""
Plot the feature effect.
Expand Down Expand Up @@ -370,7 +376,7 @@ def plot(
use_vectorized: whether to use the vectorized version of the PDP computation
"""
self._plot(
ret = self._plot(
feature,
heterogeneity,
centering,
Expand All @@ -382,6 +388,9 @@ def plot(
y_limits,
use_vectorized,
)
if not show_plot:
return ret



class DerPDP(PDPBase):
Expand Down Expand Up @@ -492,6 +501,7 @@ def plot(
show_avg_output: bool = False,
dy_limits: Optional[List] = None,
use_vectorized: bool = True,
show_plot: bool = True,
):
"""
Plot the feature effect.
Expand Down Expand Up @@ -531,8 +541,9 @@ def plot(
- If set to a tuple, the limits are manually set
use_vectorized: whether to use the vectorized version of the PDP computation
show_plot: whether to show the plot
"""
self._plot(
ret = self._plot(
feature,
heterogeneity,
centering,
Expand All @@ -543,8 +554,13 @@ def plot(
show_avg_output,
dy_limits,
use_vectorized,
show_plot,
)

if not show_plot:
fig, ax = ret
return fig, ax


def ice_non_vectorized(
model: callable,
Expand Down
9 changes: 7 additions & 2 deletions effector/global_effect_shap.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,7 +292,8 @@ def plot(
show_avg_output: bool = False,
y_limits: Optional[List] = None,
only_shap_values: bool = False,
) -> None:
show_plot: bool = True,
) -> Union[Tuple, None]:
"""
Plot the SHAP Dependence Plot (SDP) of the s-th feature.
Expand All @@ -317,6 +318,7 @@ def plot(
show_avg_output: whether to show the average output of the model
y_limits: limits of the y-axis
only_shap_values: whether to plot only the shap values
show_plot: whether to show the plot
"""
heterogeneity = helpers.prep_confidence_interval(heterogeneity)

Expand Down Expand Up @@ -354,7 +356,7 @@ def plot(
else:
avg_output = None

vis.plot_shap(
ret = vis.plot_shap(
x,
y,
xx,
Expand All @@ -369,4 +371,7 @@ def plot(
target_name=self.target_name,
y_limits=y_limits,
only_shap_values=only_shap_values,
show_plot=show_plot,
)

return ret
24 changes: 20 additions & 4 deletions effector/visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ def ale_plot(
y_limits: typing.Union[None, tuple] = None,
dy_limits: typing.Union[None, tuple] = None,
show_only_aggregated: bool = False,
show_plot: bool = True,
):
"""
Expand Down Expand Up @@ -115,7 +116,13 @@ def ale_plot(
ax2.set_xlabel(x_name)
ax2.set_ylabel("dy/dx")

plt.show(block=False)
if show_plot:
plt.show(block=False)
else:
if show_only_aggregated:
return fig, ax1
else:
return fig, (ax1, ax2)


def ale_curve(ax1, x, y, avg_output=None):
Expand Down Expand Up @@ -158,6 +165,7 @@ def plot_pdp_ice(
is_derivative: bool = False,
nof_ice: typing.Union[str, int] = "all",
y_limits: typing.Union[None, tuple] = None,
show_plot: bool = True,
):

fig, ax = plt.subplots()
Expand Down Expand Up @@ -245,8 +253,11 @@ def plot_pdp_ice(
if y_limits is not None:
ax.set_ylim(y_limits[0], y_limits[1])

plt.show(block=False)
return fig, ax
if show_plot:
plt.show(block=False)
else:
return fig, ax



def plot_shap(
Expand All @@ -264,6 +275,7 @@ def plot_shap(
target_name: typing.Union[None, str] = None,
y_limits: typing.Union[None, tuple] = None,
only_shap_values: bool = False,
show_plot: bool = True,
):

fig, ax = plt.subplots()
Expand Down Expand Up @@ -313,4 +325,8 @@ def plot_shap(
ax.legend()
if y_limits is not None:
ax.set_ylim(y_limits[0], y_limits[1])
plt.show(block=False)

if show_plot:
plt.show(block=False)
else:
return fig, ax

0 comments on commit 80d2d38

Please sign in to comment.