Skip to content

Commit

Permalink
apply suggestions from review
Browse files Browse the repository at this point in the history
  • Loading branch information
BradyPlanden committed Sep 9, 2024
1 parent 04d1d19 commit 9018f75
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 5 deletions.
14 changes: 12 additions & 2 deletions pybop/costs/_likelihoods.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,17 @@ class LogPosterior(BaseLikelihood):
Computes the log posterior which is proportional to the sum of the log
likelihood and the log prior.
Inherits all parameters and attributes from ``BaseLikelihood``.
Parameters
----------
log_likelihood : BaseLikelihood
The likelihood class of type ``BaseLikelihood``.
log_prior : Optional, Union[pybop.BasePrior, stats.rv_continuous]
The prior class of type ``BasePrior`` or ``stats.rv_continuous``.
If not provided, the prior class will be taken from the parameter priors
constructed in the `pybop.Parameters` class.
gradient_step : float, default: 1e-3
The step size for the finite-difference gradient calculation
if the ``log_prior`` is not of type ``BasePrior``.
"""

def __init__(
Expand Down Expand Up @@ -260,7 +270,7 @@ def compute(
self.verify_args(dy, calculate_grad)

if calculate_grad:
if hasattr(self._prior, "logpdfS1"):
if isinstance(self._prior, BasePrior):
log_prior, dp = self._prior.logpdfS1(self.parameters.current_value())
else:
# Compute log prior first
Expand Down
6 changes: 3 additions & 3 deletions pybop/plotting/plot2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import numpy as np
from scipy.interpolate import griddata

from pybop import BaseOptimiser, Optimisation, PlotlyManager
from pybop import BaseCost, BaseOptimiser, Optimisation, PlotlyManager


def plot2d(
Expand Down Expand Up @@ -64,11 +64,11 @@ def plot2d(
cost = cost_or_optim
plot_optim = False

if hasattr(cost, "parameters") and len(cost.parameters) < 2:
if isinstance(cost, BaseCost) and len(cost.parameters) < 2:
raise ValueError("This cost function takes fewer than 2 parameters.")

additional_values = []
if hasattr(cost, "parameters") and len(cost.parameters) > 2:
if isinstance(cost, BaseCost) and len(cost.parameters) > 2:
warnings.warn(
"This cost function requires more than 2 parameters. "
"Plotting in 2d with fixed values for the additional parameters.",
Expand Down

0 comments on commit 9018f75

Please sign in to comment.