diff --git a/pymc_marketing/mmm/transformers.py b/pymc_marketing/mmm/transformers.py index 58fa6c50b..405ae94a4 100644 --- a/pymc_marketing/mmm/transformers.py +++ b/pymc_marketing/mmm/transformers.py @@ -924,47 +924,50 @@ def hill_saturation( .. plot:: :context: close-figs + import numpy as np import matplotlib.pyplot as plt from pymc_marketing.mmm.transformers import hill_saturation x = np.linspace(0, 10, 100) # Varying sigma sigmas = [0.5, 1, 1.5] - plt.figure(figsize=(12, 4)) + fig, axes = plt.subplots(1, 3, figsize=(12, 4), sharey=True) for i, sigma in enumerate(sigmas): plt.subplot(1, 3, i+1) - y = hill_saturation(x, sigma, 2, 5) + y = hill_saturation(x, sigma, 2, 5).eval() plt.plot(x, y) plt.xlabel('x') - plt.ylabel('Hill Saturation') plt.title(f'Sigma = {sigma}') + plt.subplot(1,3,1) + plt.ylabel('Hill Saturation') plt.tight_layout() plt.show() # Varying beta betas = [1, 2, 3] - plt.figure(figsize=(12, 4)) + fig, axes = plt.subplots(1, 3, figsize=(12, 4), sharey=True) for i, beta in enumerate(betas): plt.subplot(1, 3, i+1) - y = hill_saturation(x, 1, beta, 5) + y = hill_saturation(x, 1, beta, 5).eval() plt.plot(x, y) plt.xlabel('x') - plt.ylabel('Hill Saturation') plt.title(f'Beta = {beta}') + plt.subplot(1,3,1) + plt.ylabel('Hill Saturation') plt.tight_layout() plt.show() # Varying lam lams = [3, 5, 7] - plt.figure(figsize=(12, 4)) + fig, axes = plt.subplots(1, 3, figsize=(12, 4), sharey=True) for i, lam in enumerate(lams): plt.subplot(1, 3, i+1) - y = hill_saturation(x, 1, 2, lam) + y = hill_saturation(x, 1, 2, lam).eval() plt.plot(x, y) plt.xlabel('x') - plt.ylabel('Hill Saturation') plt.title(f'Lambda = {lam}') + plt.subplot(1,3,1) + plt.ylabel('Hill Saturation') plt.tight_layout() plt.show() - Parameters ---------- x : float or array-like