From 12a5b07556ea88934522a6db65d22f30fcda9fef Mon Sep 17 00:00:00 2001 From: Patrick Robotham Date: Thu, 25 Jul 2024 15:21:04 +1000 Subject: [PATCH] Fix Visual for hill_saturation function (Issue #851 ) (#857) * Fix plotting by evaluating tensors. * Add space after sphinx directive. * Remove indentation from blank line. * Add shared y axis for subplots. --------- Co-authored-by: Patrick Robotham Co-authored-by: Will Dean <57733339+wd60622@users.noreply.github.com> --- pymc_marketing/mmm/transformers.py | 23 +++++++++++++---------- 1 file changed, 13 insertions(+), 10 deletions(-) diff --git a/pymc_marketing/mmm/transformers.py b/pymc_marketing/mmm/transformers.py index 58fa6c50b..405ae94a4 100644 --- a/pymc_marketing/mmm/transformers.py +++ b/pymc_marketing/mmm/transformers.py @@ -924,47 +924,50 @@ def hill_saturation( .. plot:: :context: close-figs + import numpy as np import matplotlib.pyplot as plt from pymc_marketing.mmm.transformers import hill_saturation x = np.linspace(0, 10, 100) # Varying sigma sigmas = [0.5, 1, 1.5] - plt.figure(figsize=(12, 4)) + fig, axes = plt.subplots(1, 3, figsize=(12, 4), sharey=True) for i, sigma in enumerate(sigmas): plt.subplot(1, 3, i+1) - y = hill_saturation(x, sigma, 2, 5) + y = hill_saturation(x, sigma, 2, 5).eval() plt.plot(x, y) plt.xlabel('x') - plt.ylabel('Hill Saturation') plt.title(f'Sigma = {sigma}') + plt.subplot(1,3,1) + plt.ylabel('Hill Saturation') plt.tight_layout() plt.show() # Varying beta betas = [1, 2, 3] - plt.figure(figsize=(12, 4)) + fig, axes = plt.subplots(1, 3, figsize=(12, 4), sharey=True) for i, beta in enumerate(betas): plt.subplot(1, 3, i+1) - y = hill_saturation(x, 1, beta, 5) + y = hill_saturation(x, 1, beta, 5).eval() plt.plot(x, y) plt.xlabel('x') - plt.ylabel('Hill Saturation') plt.title(f'Beta = {beta}') + plt.subplot(1,3,1) + plt.ylabel('Hill Saturation') plt.tight_layout() plt.show() # Varying lam lams = [3, 5, 7] - plt.figure(figsize=(12, 4)) + fig, axes = plt.subplots(1, 3, figsize=(12, 4), sharey=True) for i, lam in enumerate(lams): plt.subplot(1, 3, i+1) - y = hill_saturation(x, 1, 2, lam) + y = hill_saturation(x, 1, 2, lam).eval() plt.plot(x, y) plt.xlabel('x') - plt.ylabel('Hill Saturation') plt.title(f'Lambda = {lam}') + plt.subplot(1,3,1) + plt.ylabel('Hill Saturation') plt.tight_layout() plt.show() - Parameters ---------- x : float or array-like