Skip to content

Commit

Permalink
Fix Visual for hill_saturation function (Issue pymc-labs#851 ) (pymc…
Browse files Browse the repository at this point in the history
…-labs#857)

* Fix plotting by evaluating tensors.

* Add space after sphinx directive.

* Remove indentation from blank line.

* Add shared y axis for subplots.

---------

Co-authored-by: Patrick Robotham <[email protected]>
Co-authored-by: Will Dean <[email protected]>
  • Loading branch information
3 people authored and radiokosmos committed Sep 1, 2024
1 parent 1021e32 commit 12a5b07
Showing 1 changed file with 13 additions and 10 deletions.
23 changes: 13 additions & 10 deletions pymc_marketing/mmm/transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -924,47 +924,50 @@ def hill_saturation(
.. plot::
:context: close-figs
import numpy as np
import matplotlib.pyplot as plt
from pymc_marketing.mmm.transformers import hill_saturation
x = np.linspace(0, 10, 100)
# Varying sigma
sigmas = [0.5, 1, 1.5]
plt.figure(figsize=(12, 4))
fig, axes = plt.subplots(1, 3, figsize=(12, 4), sharey=True)
for i, sigma in enumerate(sigmas):
plt.subplot(1, 3, i+1)
y = hill_saturation(x, sigma, 2, 5)
y = hill_saturation(x, sigma, 2, 5).eval()
plt.plot(x, y)
plt.xlabel('x')
plt.ylabel('Hill Saturation')
plt.title(f'Sigma = {sigma}')
plt.subplot(1,3,1)
plt.ylabel('Hill Saturation')
plt.tight_layout()
plt.show()
# Varying beta
betas = [1, 2, 3]
plt.figure(figsize=(12, 4))
fig, axes = plt.subplots(1, 3, figsize=(12, 4), sharey=True)
for i, beta in enumerate(betas):
plt.subplot(1, 3, i+1)
y = hill_saturation(x, 1, beta, 5)
y = hill_saturation(x, 1, beta, 5).eval()
plt.plot(x, y)
plt.xlabel('x')
plt.ylabel('Hill Saturation')
plt.title(f'Beta = {beta}')
plt.subplot(1,3,1)
plt.ylabel('Hill Saturation')
plt.tight_layout()
plt.show()
# Varying lam
lams = [3, 5, 7]
plt.figure(figsize=(12, 4))
fig, axes = plt.subplots(1, 3, figsize=(12, 4), sharey=True)
for i, lam in enumerate(lams):
plt.subplot(1, 3, i+1)
y = hill_saturation(x, 1, 2, lam)
y = hill_saturation(x, 1, 2, lam).eval()
plt.plot(x, y)
plt.xlabel('x')
plt.ylabel('Hill Saturation')
plt.title(f'Lambda = {lam}')
plt.subplot(1,3,1)
plt.ylabel('Hill Saturation')
plt.tight_layout()
plt.show()
Parameters
----------
x : float or array-like
Expand Down

0 comments on commit 12a5b07

Please sign in to comment.