Skip to content

Commit

Permalink
More flexible type for BudgetOptimizer (#1429)
Browse files Browse the repository at this point in the history
  • Loading branch information
ricardoV94 authored Jan 24, 2025
1 parent 9994ee9 commit 9f4e76f
Showing 1 changed file with 16 additions and 4 deletions.
20 changes: 16 additions & 4 deletions pymc_marketing/mmm/budget_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,14 @@

import warnings
from collections.abc import Sequence
from typing import Any, ClassVar
from typing import Any, ClassVar, Protocol, runtime_checkable

import arviz as az
import numpy as np
import pytensor.tensor as pt
import xarray as xr
from pydantic import BaseModel, ConfigDict, Field
from arviz import InferenceData
from pydantic import BaseModel, ConfigDict, Field, InstanceOf
from pymc import Model, do
from pymc.logprob.utils import rvs_in_graph
from pymc.model.transform.optimization import freeze_dims_and_data
Expand All @@ -39,7 +40,6 @@
build_default_sum_constraint,
compile_constraints_for_scipy,
)
from pymc_marketing.mmm.mmm import MMM
from pymc_marketing.mmm.utility import UtilityFunctionType, average_response


Expand Down Expand Up @@ -83,6 +83,18 @@ def __init__(self, message: str):
super().__init__(message)


@runtime_checkable
class OptimizerCompatibleModelWrapper(Protocol):
"""Protocol for marketing mix model wrappers compatible with the BudgetOptimizer."""

adstock: Any
_channel_scales: Any
idata: InferenceData

def _set_predictors_for_optimization(self, num_periods: int) -> Model:
"""Set the predictors for optimization."""


class BudgetOptimizer(BaseModel):
"""A class for optimizing budget allocation in a marketing mix model.
Expand All @@ -109,7 +121,7 @@ class BudgetOptimizer(BaseModel):
description="Number of time units at the desired time granularity to allocate budget for.",
)

mmm_model: MMM = Field(
mmm_model: InstanceOf[OptimizerCompatibleModelWrapper] = Field(
...,
description="The marketing mix model to optimize.",
arbitrary_types_allowed=True,
Expand Down

0 comments on commit 9f4e76f

Please sign in to comment.