Interactive MMM Visualizations with plot_interactive#
This notebook demonstrates the interactive plotting capabilities of the
plot_interactive module in pymc-marketing. These Plotly-based visualizations
are designed for exploring MMM results interactively — hovering over data points,
zooming into time ranges, and faceting across custom dimensions like markets.
We use the multidimensional MMM example data (two geos: geo_a and geo_b,
two channels: x1 and x2) and fit the model using the same setup as the
MMM Multidimensional Example Notebook.
What this notebook covers:
Posterior predictive checks (actual vs predicted)
ROAS analysis at different time granularities
Channel contributions over time
Saturation curves (diminishing returns)
Adstock curves (carryover effects)
Auto-faceting across custom dimensions
Advanced: Filtering and aggregating data with MMMSummaryFactory and MMMPlotlyFactory
import warnings
import numpy as np
import pandas as pd
import plotly.io as pio
from pymc_extras.prior import Prior
from pymc_marketing.mmm import GeometricAdstock, LogisticSaturation
from pymc_marketing.mmm.multidimensional import MMM
from pymc_marketing.paths import data_dir
from pymc_marketing.special_priors import LaplacePrior, LogNormalPrior
warnings.filterwarnings("ignore", category=UserWarning)
pio.renderers.default = "notebook_connected"
seed: int = sum(map(ord, "plot_interactive"))
rng: np.random.Generator = np.random.default_rng(seed=seed)
Load Data & Fit Model#
We load the same simulated multidimensional dataset used in the
MMM Multidimensional Example.
The data has two geographies (geo_a, geo_b) and two channels (x1, x2).
For model fitting details (prior specification, adstock/saturation choices, pooling strategies), please refer to the original notebook. Here we set up and fit the model quickly so we can focus on interactive visualizations.
data_path = data_dir / "mmm_multidimensional_example.csv"
data_df = pd.read_csv(data_path, parse_dates=["date"])
data_df.head()
| | date | geo | x1 | x2 | event_1 | event_2 | y |
|---|---|---|---|---|---|---|---|
| 0 | 2022-06-06 | geo_a | 5527.640078 | 0.000000 | 0 | 0 | 2647.596355 |
| 1 | 2022-06-06 | geo_b | 8849.257500 | 8063.918386 | 0 | 0 | 682.406280 |
| 2 | 2022-06-13 | geo_a | 6692.655692 | 0.000000 | 0 | 0 | 5020.823907 |
| 3 | 2022-06-13 | geo_b | 9073.817994 | 9354.014585 | 0 | 0 | 3753.104897 |
| 4 | 2022-06-20 | geo_a | 7124.016733 | 0.000000 | 0 | 0 | 6184.322132 |
# --- Prior Specification ---
# Hierarchical beta (partially pooled across geos)
beta_prior = LogNormalPrior(
    mean=Prior("Gamma", mu=0.25, sigma=0.10, dims=("channel")),
    std=Prior("Exponential", scale=0.10, dims=("channel")),
    dims=("channel", "geo"),
    centered=False,
)
# Saturation: beta is hierarchical, lambda is fully pooled
saturation = LogisticSaturation(
    priors={
        "beta": beta_prior,
        "lam": Prior("Gamma", mu=0.5, sigma=0.25, dims=("channel")),
    }
)
# Adstock: unpooled (each geo x channel has its own alpha)
adstock = GeometricAdstock(
    priors={"alpha": Prior("Beta", alpha=2, beta=5, dims=("geo", "channel"))},
    l_max=8,
)
# Model config
model_config = {
    "intercept": Prior("Gamma", mu=0.5, sigma=0.25, dims="geo"),
    "gamma_control": Prior("Normal", mu=0, sigma=0.5, dims="control"),
    "gamma_fourier": LaplacePrior(
        mu=0,
        b=Prior("HalfNormal", sigma=0.2),
        dims=("geo", "fourier_mode"),
        centered=False,
    ),
    "likelihood": Prior(
        "TruncatedNormal",
        lower=0,
        sigma=Prior("HalfNormal", sigma=1.5),
        dims=("date", "geo"),
    ),
}
# --- Model Definition ---
mmm = MMM(
    date_column="date",
    target_column="y",
    channel_columns=["x1", "x2"],
    control_columns=["event_1", "event_2"],
    dims=("geo",),
    scaling={
        "channel": {"method": "max", "dims": ()},
        "target": {"method": "max", "dims": ()},
    },
    adstock=adstock,
    saturation=saturation,
    yearly_seasonality=2,
    model_config=model_config,
)
# --- Fit ---
x_train = data_df.drop(columns=["y"])
y_train = data_df["y"]
mmm.fit(
    X=x_train,
    y=y_train,
    chains=4,
    target_accept=0.95,
    random_seed=rng,
)
# --- Add original scale deterministic variables ---
mmm.build_model(X=x_train, y=y_train)
mmm.add_original_scale_contribution_variable(
    var=[
        "channel_contribution",
        "control_contribution",
        "intercept_contribution",
        "yearly_seasonality_contribution",
        "y",
    ]
)
# --- Posterior predictive ---
mmm.sample_posterior_predictive(X=x_train, random_seed=rng)
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [intercept_contribution, adstock_alpha, saturation_lam, saturation_beta_mean, saturation_beta_std, saturation_beta_log_offset, gamma_control, gamma_fourier_b, gamma_fourier_sigma, gamma_fourier_offset, y_sigma]
Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 16 seconds.
Sampling: [y]
<xarray.Dataset> Size: 20MB
Dimensions: (date: 159, geo: 2, sample: 4000)
Coordinates:
* date (date) datetime64[ns] 1kB 2022-06-06 ... 2025-06-16
* geo (geo) <U5 40B 'geo_a' 'geo_b'
* sample (sample) object 32kB MultiIndex
* chain (sample) int64 32kB 0 0 0 0 0 0 0 0 0 ... 3 3 3 3 3 3 3 3
* draw (sample) int64 32kB 0 1 2 3 4 5 ... 995 996 997 998 999
Data variables:
y (date, geo, sample) float64 10MB 0.2945 0.05857 ... 0.2187
y_original_scale (date, geo, sample) float64 10MB 4.068e+03 ... 2.407e+03
Attributes:
created_at: 2026-02-10T13:45:17.470721+00:00
arviz_version: 0.23.4
inference_library: pymc
inference_library_version: 5.27.1
1. Posterior Predictive: How Well Does the Model Fit?#
The posterior predictive plot shows the model’s predictions (with uncertainty) against the observed data. This is the first thing to check after fitting — do the predictions track the actual sales?
The interactive plot lets you hover over any point to see exact values, and zoom into specific time periods.
mmm.plot_interactive.posterior_predictive()
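A numeric companion to the visual check is interval coverage: the share of observed values that fall inside the predictive interval. The sketch below uses only the standard library and made-up data. The helper names (`central_interval`, `coverage`) are hypothetical, and it uses an equal-tailed interval rather than whatever interval the plot draws.

```python
import random


def central_interval(draws, prob=0.94):
    """Return the equal-tailed interval holding roughly `prob` of the draws."""
    ordered = sorted(draws)
    n = len(ordered)
    tail = (1.0 - prob) / 2.0
    lower_idx = int(tail * (n - 1))
    upper_idx = int(round((1.0 - tail) * (n - 1)))
    return ordered[lower_idx], ordered[upper_idx]


def coverage(observed, predictive_draws, prob=0.94):
    """Fraction of observations that fall inside their predictive interval."""
    hits = 0
    for y_obs, draws in zip(observed, predictive_draws):
        lower, upper = central_interval(draws, prob)
        hits += lower <= y_obs <= upper
    return hits / len(observed)


# Hypothetical data: 50 time points, 500 predictive draws per point.
rng_demo = random.Random(0)
truth = [100.0 + 2.0 * t for t in range(50)]
observed = [mu + rng_demo.gauss(0.0, 5.0) for mu in truth]
predictive_draws = [
    [mu + rng_demo.gauss(0.0, 5.0) for _ in range(500)] for mu in truth
]

print(f"94% interval coverage: {coverage(observed, predictive_draws):.2f}")
```

Coverage far below the nominal level suggests the model is overconfident; coverage near 1.0 for a narrow nominal level suggests intervals that are too wide.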
2. ROAS Analysis: Which Channels Give the Best Return?#
ROAS (Return on Ad Spend) is one of the most important metrics for marketers.
The plot_interactive module makes it easy to slice and dice ROAS across
different time granularities and dimensions.
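To make the metric concrete: ROAS for a period is the channel's contribution divided by its spend over that period. The sketch below shows that arithmetic on made-up point values with a hypothetical helper, `roas_by_period`. It is not the library's implementation; in a Bayesian model the same ratio would presumably be computed per posterior draw before summarizing.

```python
from collections import defaultdict
from datetime import date


def roas_by_period(rows, period_key):
    """Sum contribution and spend per (period, channel), then divide."""
    contribution = defaultdict(float)
    spend = defaultdict(float)
    for row in rows:
        key = (period_key(row["date"]), row["channel"])
        contribution[key] += row["contribution"]
        spend[key] += row["spend"]
    # Skip periods with zero spend, where ROAS is undefined.
    return {key: contribution[key] / spend[key] for key in spend if spend[key] > 0}


# Hypothetical weekly rows for one geo (values are made up for illustration).
rows = [
    {"date": date(2023, 12, 25), "channel": "x1", "spend": 100.0, "contribution": 150.0},
    {"date": date(2024, 1, 1), "channel": "x1", "spend": 200.0, "contribution": 500.0},
    {"date": date(2024, 1, 8), "channel": "x1", "spend": 300.0, "contribution": 500.0},
    {"date": date(2024, 1, 8), "channel": "x2", "spend": 0.0, "contribution": 0.0},
]

yearly = roas_by_period(rows, period_key=lambda d: d.year)
all_time = roas_by_period(rows, period_key=lambda d: "all_time")
print(yearly)  # {(2023, 'x1'): 1.5, (2024, 'x1'): 2.0}
print(all_time)
```

Note that totals are summed before dividing, so a period's ROAS is spend-weighted rather than an average of weekly ratios.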
Q: How did the ROAS of each channel change year after year?#
mmm.plot_interactive.roas(
    frequency="yearly",
    color="date",
    x="channel",
)
Q: Within each year, which channel performed better in each geo?#
By swapping x and color, we get a different perspective — now the x-axis
shows time and the color differentiates channels.
mmm.plot_interactive.roas(
    frequency="yearly",
    color="channel",
    x="date",
)
Q: Looking over all the data, which channel performed better in each geo?#
Using frequency="all_time" aggregates everything into a single time period,
giving us the overall ROAS per channel per geo.
mmm.plot_interactive.roas(frequency="all_time")
3. Channel Contributions: What Drives Sales?#
The contributions plot shows how much each channel contributes to total sales
over time. It supports the same frequency, color, and x parameters as
roas() above, so you can slice contributions in exactly the same way —
by time granularity, channel, or geography.
Q: What is the overall contribution of each channel across all geos?#
Error bars show the 94% HDI (Highest Density Interval).
mmm.plot_interactive.contributions(frequency="all_time")
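For readers unfamiliar with the HDI: it is the narrowest interval that contains a given share of the posterior samples. The sketch below is a minimal standard-library version for a unimodal sample. The function name `hdi` and the lognormal demo data are illustrative assumptions, not the routine used by the plots.

```python
import math
import random


def hdi(samples, prob=0.94):
    """Narrowest interval containing at least `prob` of the samples."""
    ordered = sorted(samples)
    n = len(ordered)
    window = max(1, math.ceil(prob * n))  # number of samples inside the interval
    # Slide a window of that size over the sorted samples; keep the narrowest.
    best_start = min(
        range(n - window + 1),
        key=lambda i: ordered[i + window - 1] - ordered[i],
    )
    return ordered[best_start], ordered[best_start + window - 1]


# Hypothetical right-skewed "posterior" of a channel contribution.
rng_demo = random.Random(1)
draws = [rng_demo.lognormvariate(0.0, 0.5) for _ in range(4000)]
low, high = hdi(draws, prob=0.94)
print(f"94% HDI: [{low:.3f}, {high:.3f}]")
```

For skewed posteriors the HDI sits closer to the mode than an equal-tailed interval, which is why error bars can look asymmetric around the mean.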
Q: How did channel contributions change year by year?#
Just like the ROAS examples, setting frequency="yearly", color="channel",
and x="date" gives us a time-series view colored by channel.
mmm.plot_interactive.contributions(
    frequency="yearly",
    color="channel",
    x="date",
)
Q: How did the contributions of control variables change year by year?#
Besides channel contributions, we can also plot the contributions of control, seasonality, and baseline components. Here hdi_prob=None removes the error bars.
mmm.plot_interactive.contributions(
    component="control",
    frequency="yearly",
    color="control",
    x="date",
    hdi_prob=None,
)
4. Saturation Curves: Where Are Diminishing Returns?#
Saturation curves show how the response (contribution) changes as spend increases. These are essential for understanding where additional spend will have diminishing returns.
Q: Show me the saturation curves in original scale#
By default, the x-axis is in original scale (e.g., dollars of spend). Each channel gets its own line, faceted by geo.
mmm.plot_interactive.saturation_curves()
Sampling: []
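To see where the curve shape comes from, here is a minimal sketch of a logistic saturation function, written from the commonly used form beta * (1 - exp(-lam * x)) / (1 + exp(-lam * x)). The parameter values are hypothetical, and the spend axis is assumed to be on the model's scaled range (0 to 1 under max scaling), not in original units.

```python
import math


def logistic_saturation(x, lam, beta=1.0):
    """beta * (1 - exp(-lam * x)) / (1 + exp(-lam * x)); approaches beta as x grows."""
    return beta * (1.0 - math.exp(-lam * x)) / (1.0 + math.exp(-lam * x))


# Hypothetical parameters chosen only for illustration.
lam, beta = 4.0, 0.3
spend_grid = [0.0, 0.25, 0.5, 0.75, 1.0]
response = [logistic_saturation(x, lam, beta) for x in spend_grid]

# Marginal gain per equal spend step shrinks as spend grows: diminishing returns.
gains = [b - a for a, b in zip(response, response[1:])]

for x, r in zip(spend_grid, response):
    print(f"scaled spend={x:.2f}  response={r:.4f}")
print("gains per step:", [round(g, 4) for g in gains])
```

Larger lam makes the curve saturate at lower spend; beta sets the ceiling the response approaches.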
Q: Now show me with uncertainty (HDI bands)#
Adding hdi_prob=0.9 draws shaded bands showing the 90% HDI around
each curve — capturing posterior uncertainty in the saturation parameters.
mmm.plot_interactive.saturation_curves(hdi_prob=0.9)
Sampling: []
5. Adstock Curves: How Long Do Effects Last?#
Adstock (carryover) curves show how the effect of a marketing impulse decays over time. A slow decay means the channel has long-lasting effects; a fast decay means the effect is short-lived.
Q: How do the decay curves look?#
mmm.plot_interactive.adstock_curves(hdi_prob=None)
Sampling: []
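The decay shape can be reproduced by hand. The sketch below applies geometric adstock weights (alpha raised to the lag) to a single spend impulse, using the l_max=8 from the model above and a hypothetical alpha=0.5. The helper names and the `normalize` flag are illustrative assumptions; the library's own transformation may differ in details such as normalization.

```python
def geometric_weights(alpha, l_max, normalize=True):
    """Weights alpha**lag for lag = 0 .. l_max - 1, optionally summing to 1."""
    weights = [alpha**lag for lag in range(l_max)]
    if normalize:
        total = sum(weights)
        weights = [w / total for w in weights]
    return weights


def apply_adstock(spend, alpha, l_max, normalize=True):
    """Carry each period's spend forward over the next l_max - 1 periods."""
    weights = geometric_weights(alpha, l_max, normalize)
    out = []
    for t in range(len(spend)):
        out.append(
            sum(w * spend[t - lag] for lag, w in enumerate(weights) if t - lag >= 0)
        )
    return out


# A single impulse of 100 in week 0, nothing afterwards.
impulse = [100.0] + [0.0] * 9
carried = apply_adstock(impulse, alpha=0.5, l_max=8, normalize=False)
print(carried[:4])  # [100.0, 50.0, 25.0, 12.5]
```

With alpha close to 1 the effect lingers for many weeks; with alpha close to 0 almost all of it lands in the week of the spend. After l_max periods the carryover is cut off entirely.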
Q: Show me adstock curves with uncertainty#
Adding HDI bands helps us understand how confident we are about the carryover duration for each channel.
mmm.plot_interactive.adstock_curves(hdi_prob=0.9)
Sampling: []
6. Auto-faceting Across Custom Dimensions#
When your model includes custom dimensions (like geo in our example),
plot_interactive automatically creates subplots (facets) for each
dimension value. You’ve already seen this in all the plots above — each
subplot corresponds to a different geography (geo_a and geo_b).
This behavior is controlled by the auto_facet parameter, which is
enabled by default. You can also control the faceting layout with
facet_col and facet_row:
facet_col: creates side-by-side columns (one per dimension value)
facet_row: stacks subplots vertically (one per dimension value)
single_dim_facet: controls the default direction ("col" or "row")
Q: Compare saturation curves across geos stacked vertically#
Using facet_row="geo" overrides the default column layout and stacks
the saturation curves vertically instead.
mmm.plot_interactive.saturation_curves(
    facet_row="geo",
)
Sampling: []
Q: What if I want to have all lines in the same plot?#
Setting auto_facet=False disables the automatic subplot creation.
For line-based plots (like saturation and adstock curves), the custom
dimension is then shown using line dash styles instead of separate
subplots — all curves appear on a single plot, differentiated by dashing.
mmm.plot_interactive.saturation_curves(
    auto_facet=False,
)
Sampling: []
Q: Show adstock curves with geos as columns instead of the default#
By passing single_dim_facet="col", the single custom dimension (geo)
is faceted as columns rather than the default rows for this plot type.
mmm.plot_interactive.adstock_curves(
    single_dim_facet="col",
)
Sampling: []
7. Advanced: Filtering & Aggregating with the Data Layer#
This is an advanced section. The examples above cover the most common use cases through
mmm.plot_interactive. This section shows how to work directly with the underlying components for custom data slicing.
Under the hood, mmm.plot_interactive is powered by two key classes:
MMMSummaryFactory: Takes a data wrapper (from mmm.data) and the fitted model, and computes summary statistics like contributions, ROAS, and posterior predictive values. It handles HDI computation, time aggregation, and proper scaling. You can think of it as the “data engine” that turns raw InferenceData into plottable DataFrames.
MMMPlotlyFactory: Takes an MMMSummaryFactory and provides all the interactive plotting methods (posterior_predictive(), roas(), contributions(), saturation_curves(), adstock_curves()). It reads from the summary factory and creates Plotly figures.
When you call mmm.plot_interactive, it automatically creates both of
these using your full dataset. But you can also create them manually
with filtered or aggregated data — this is the key to custom views.
The workflow is:
Transform the data using mmm.data.filter_dims(), mmm.data.filter_dates(), or mmm.data.aggregate_dims()
Create a new MMMSummaryFactory with the transformed data
Create a new MMMPlotlyFactory with that summary
Plot using the factory’s methods
Q: Aggregating over geos, how did the ROAS of each channel change year after year?#
Here we aggregate both geos into a single “all_geos” label, then plot ROAS. This collapses the geo dimension so we get a single aggregated view.
from pymc_marketing.mmm.plot_interactive import MMMPlotlyFactory
from pymc_marketing.mmm.summary import MMMSummaryFactory
# Aggregate geo_a and geo_b into "all_geos"
agg_data = mmm.data.aggregate_dims(
    dim="geo", values=["geo_a", "geo_b"], new_label="all_geos"
)
agg_summary = MMMSummaryFactory(agg_data, mmm)
agg_factory = MMMPlotlyFactory(summary=agg_summary)
agg_factory.roas(
    frequency="yearly",
    color="channel",
    x="date",
)
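One thing worth keeping in mind when reading an aggregated ROAS: if aggregation sums spend and contribution across geos before dividing (an assumption about the aggregation, not something verified here), the result is a spend-weighted ratio, which can differ noticeably from the plain average of the per-geo ROAS values. The numbers below are made up to show the gap.

```python
# Hypothetical per-geo totals for one channel (made-up numbers).
per_geo = {
    "geo_a": {"spend": 1000.0, "contribution": 3000.0},
    "geo_b": {"spend": 4000.0, "contribution": 4000.0},
}

roas_per_geo = {geo: v["contribution"] / v["spend"] for geo, v in per_geo.items()}

# Pooled ROAS: sum first, then divide (spend-weighted).
total_spend = sum(v["spend"] for v in per_geo.values())
total_contribution = sum(v["contribution"] for v in per_geo.values())
pooled_roas = total_contribution / total_spend

# Naive alternative: unweighted mean of the per-geo ratios.
naive_mean = sum(roas_per_geo.values()) / len(roas_per_geo)

print(roas_per_geo)  # {'geo_a': 3.0, 'geo_b': 1.0}
print(pooled_roas)   # 1.4
print(naive_mean)    # 2.0
```

The pooled value is pulled toward geo_b because most of the spend went there.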
Q: Filtering to only one geo, what was the yearly ROAS?#
You can filter the data to a single geography and create a dedicated factory.
filtered_data_geo_a = mmm.data.filter_dims(geo="geo_a")
filtered_summary_geo_a = MMMSummaryFactory(
    filtered_data_geo_a, mmm, validate_data=False
)
filtered_factory_geo_a = MMMPlotlyFactory(summary=filtered_summary_geo_a)
filtered_factory_geo_a.roas(
    frequency="yearly",
    color="channel",
    x="date",
    title="ROAS for geo_a (yearly)",
)
Q: How did the ROAS change quarter after quarter starting 2024?#
Filter by date range and then view quarterly ROAS.
filtered_data_2024 = mmm.data.filter_dates(start_date="2024-01-01")
filtered_summary_2024 = MMMSummaryFactory(filtered_data_2024, mmm)
filtered_factory_2024 = MMMPlotlyFactory(summary=filtered_summary_2024)
filtered_factory_2024.roas(
    frequency="quarterly",
    color="channel",
    x="date",
    hdi_prob=None,
    title="ROAS from 2024 onwards (quarterly)",
)
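As a rough illustration of what the date filter plus quarterly frequency does to the time axis, the sketch below keeps dates on or after a start date and assigns each one a quarter label. The helper `quarter_label`, the label format, and the sample dates are hypothetical; the library may label periods differently.

```python
from datetime import date


def quarter_label(d):
    """Map a date to a label such as '2024-Q1'."""
    return f"{d.year}-Q{(d.month - 1) // 3 + 1}"


# Hypothetical weekly dates straddling the filter boundary.
dates = [
    date(2023, 12, 25),
    date(2024, 1, 1),
    date(2024, 3, 25),
    date(2024, 4, 1),
    date(2024, 11, 18),
]

start = date(2024, 1, 1)
kept = [d for d in dates if d >= start]  # drop everything before the start date
labels = [quarter_label(d) for d in kept]
print(labels)  # ['2024-Q1', '2024-Q1', '2024-Q2', '2024-Q4']
```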
Summary#
The plot_interactive module provides a rich set of interactive visualizations for
exploring MMM results:
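| Method | What It Shows | Key Parameters |
|---|---|---|
| posterior_predictive() | Actual vs predicted with HDI band | |
| contributions() | Channel/control/seasonality contributions | component, frequency, color, x, hdi_prob |
| roas() | Return on Ad Spend | frequency, color, x, hdi_prob |
| saturation_curves() | Diminishing returns curves | hdi_prob, facet_row, auto_facet |
| adstock_curves() | Carryover effect curves | hdi_prob, single_dim_facet |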
Key features:
Auto-faceting: Custom dimensions (e.g., geo) automatically create subplots
Facet control: Use facet_col, facet_row, and auto_facet to customize layout
Filtering: Use mmm.data.filter_dims() or filter_dates() to focus on subsets
Aggregating: Use mmm.data.aggregate_dims() to combine dimensions
Error bars: Control with hdi_prob (set to None to remove)
Customizable: All Plotly Express kwargs (title, height, width, colors, etc.) are supported
%load_ext watermark
%watermark -n -u -v -iv -w
Last updated: Tue, 10 Feb 2026
Python implementation: CPython
Python version : 3.13.12
IPython version : 9.10.0
numpy : 2.3.5
pandas : 2.3.3
plotly : 6.5.2
pymc_extras : 0.8.0
pymc_marketing: 0.18.0
Watermark: 2.6.0