Marketing Mix Model in Python
mmm-lightweight
August 8, 2024
[2]: #!pip install lightweight_mmm
import lightweight_mmm
[3]: import jax.numpy as jnp
import numpy
import pandas as pd
# [4]:
# NOTE: the first import rebinds the name `lightweight_mmm` (the package imported
# in cell [2]) to its `lightweight_mmm` submodule; `lightweight_mmm.LightweightMMM`
# in cell [18] is looked up on that submodule.
from lightweight_mmm import lightweight_mmm
from lightweight_mmm import utils
from lightweight_mmm import optimize_media
from lightweight_mmm import plot
from lightweight_mmm import preprocessing
# [5]:
# Interactive Colab upload. `uploaded` holds what files.upload() returns
# (a mapping of uploaded file name -> file contents per the Colab I/O docs).
from google.colab import files
import io

uploaded = files.upload()
Saving bike_sales_data.csv to bike_sales_data.csv
# [6]:
# Parse the bytes returned by files.upload() in cell [5] rather than a hard-coded
# absolute '/content/...' path: it reads exactly what was just uploaded, doesn't
# depend on the kernel's working directory, and fails loudly (KeyError) if the
# expected file is missing instead of silently reading a stale copy.
DATA_FILE = 'bike_sales_data.csv'
df = pd.read_csv(io.BytesIO(uploaded[DATA_FILE]))
# [7]:
df.head()
[7]:
0
1
2
3
4
Week
7/23/17
7/30/17
8/6/17
8/13/17
8/20/17
sales-
0
1
2
3
4
facebook_spend-
branded_search_spend-
print_spend
0
0
0
0
0
ooh_spend
0
0
0
0
0
1
nonbranded_search_spend-
tv_spend
0
0
0
0
0
radio_spend
0
0
0
0
0
\
# [8]:
# Drop the non-numeric date column so df.corr() and the models below only see
# numeric data. Rebinding (instead of inplace=True) with errors='ignore' keeps
# the cell idempotent: re-running it no longer raises KeyError for 'Week'.
df = df.drop(columns=['Week'], errors='ignore')
# [9]:
import seaborn as sns
import matplotlib.pyplot as plt

# Correlations live in [-1, 1]. Pin the colour scale there, centred on 0, so colours
# mean the same thing in every cell and the all-ones diagonal doesn't wash out the
# rest. numeric_only=True guards against any leftover non-numeric column, which
# pandas >= 2 would otherwise reject.
fig, ax = plt.subplots(figsize=(10, 8))
sns.heatmap(df.corr(numeric_only=True), annot=True, fmt='.2f',
            vmin=-1, vmax=1, center=0, cmap='coolwarm', ax=ax)
ax.set_title('Correlation between weekly sales and channel spend');
[9]:
# [10]:
# Pairwise scatter / distribution grid over every remaining column of df.
sns.pairplot(data=df)
[10]:
2
[11]: from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestRegressor
# [12]:
# Exploratory check: which spend channels does a random forest rely on to predict
# sales? The features are every column except the target.
X_rf = df.loc[:, df.columns != 'sales']
y_rf = df['sales']

# Building Random Forest model.
# NOTE(review): train_test_split shuffles rows, so this ignores the weekly time
# order (unlike the chronological hold-out used for the MMM in cell [15]). That is
# fine for a rough importance ranking, but the score below is optimistic as a
# forecast metric.
X_train_rf, X_test_rf, y_train_rf, y_test_rf = train_test_split(
    X_rf, y_rf, test_size=.25, random_state=0)
model = RandomForestRegressor(random_state=1)
model.fit(X_train_rf, y_train_rf)
pred = model.predict(X_test_rf)

# Report held-out accuracy so the importances come with a sanity check.
print(f"Random forest hold-out R^2: {model.score(X_test_rf, y_test_rf):.3f}, "
      f"MAE: {(y_test_rf - pred).abs().mean():,.1f}")

feat_importances = pd.Series(model.feature_importances_, index=X_rf.columns)
# Ascending sort so the most important feature is drawn at the top of the barh chart.
ax = feat_importances.nlargest(25).sort_values().plot(kind='barh', figsize=(10, 10))
ax.set_title('Random forest feature importance')
ax.set_xlabel('Importance');
[12]:
# [13]:
# Paid-media channels, listed once so the model inputs and the cost priors cannot
# drift apart. The column order fixes which channel each coef_media[i] refers to.
media_channels = [
    'branded_search_spend', 'nonbranded_search_spend', 'facebook_spend',
    'print_spend', 'ooh_spend', 'tv_spend', 'radio_spend',
]
# NumPy arrays despite the *_df names: media is (weeks x channels), target is weekly sales.
media_df = df[media_channels].to_numpy()
target_df = df['sales'].to_numpy()
# Total spend per channel over the full history. It is scaled in cell [16] and
# passed to mmm.fit as media_prior.
costs = df[media_channels].sum().to_numpy()
# lightweight_mmm uses media_prior as the scale of each channel's coefficient
# prior, so a channel with no spend at all would get a degenerate zero-scale
# prior. Fail here rather than mid-sampling.
assert (costs > 0).all(), dict(zip(media_channels, costs))
# [14]:
costs
[14]: array([-.
,-,-,
, 136000.
])
4
98000.
,
49000.
,
[14]:
# [15]:
# Chronological hold-out: the last TEST_WEEKS rows are never seen during fitting,
# so the out-of-sample check in cell [28] really looks at the future.
TEST_WEEKS = 25
split_point = media_df.shape[0] - TEST_WEEKS
assert split_point > 0, (
    f"need more than {TEST_WEEKS} weeks of data, got {media_df.shape[0]}")
train_media = media_df[:split_point]
test_media = media_df[split_point:]
train_target = target_df[:split_point]
test_target = target_df[split_point:]
# [ ]:
# (Removed `!pip install jax`.) This cell was never executed, and jax is already
# importable here because cell [3] imports jax.numpy, so the install would be a
# no-op. Any package installs belong in the setup cell at the top of the notebook.
# [16]:
# Three independent mean-based scalers (media, target, costs), created in that order.
media_scaler, target_scaler, cost_scaler = (
    preprocessing.CustomScaler(divide_operation=jnp.mean) for _ in range(3)
)

# Fit each scaler on its own training input.
media_data_train = media_scaler.fit_transform(train_media)
target_train = target_scaler.fit_transform(train_target)
costs2 = cost_scaler.fit_transform(costs)
# [18]:
# Model variant: "hill_adstock" (the previously tried alternative was
# model_name="carryover").
mmm = lightweight_mmm.LightweightMMM(model_name="hill_adstock")
# [22]:
number_warmup = 1000
number_samples = 1000
# Fixed PRNG seed so the MCMC draws, and every plot derived from them, are reproducible.
mcmc_seed = 1
# [23]:
# NOTE(review): the run recorded below reports 998 divergences for 1000 samples,
# with r_hat values up to ~1.2 and very low n_eff for some parameters. Treat the
# posterior as unreliable until that is resolved, e.g. by checking the inputs and
# priors, then running more chains so r_hat is meaningful.
mmm.fit(
    media=media_data_train,
    media_prior=costs2,
    target=target_train,
    number_warmup=number_warmup,
    number_samples=number_samples,
    number_chains=1,
    seed=mcmc_seed,
)
sample: 100%|██████████| 2000/2000 [00:33<00:00, 59.89it/s, 114 steps of size
9.10e-03. acc. prob=0.96]
# Posterior summary per parameter (mean, std, median, 5%/95% quantiles, n_eff,
# r_hat) followed by the divergence count.
# NOTE(review): the recorded output reports 998 divergences and several low n_eff /
# r_hat > 1 values, i.e. the sampler has not converged. Fix this before trusting
# the media insights further down.
[24]: mmm.print_summary()
95.0%
n_eff
0.51
28.46
0.33
21.84
1.05
10.60
r_hat
coef_media[0]
1.02
coef_media[1]
1.00
coef_media[2]
1.24
coef_media[3]
mean
std
median
5.0%
0.41
0.08
0.40
0.28
0.27
0.05
0.27
0.19
0.54
0.34
0.46
0.06
0.21
0.14
0.18
0.01
5
0.41
8.69
1.02
coef_media[4]-
coef_media[5]-
coef_media[6]-
coef_trend[0]-
expo_trend-
gamma_seasonality[0,0]-
gamma_seasonality[0,1]-
gamma_seasonality[1,0]
-
gamma_seasonality[1,1]
-
half_max_effective_concentration[0]-
half_max_effective_concentration[1]-
half_max_effective_concentration[2]-
half_max_effective_concentration[3]-
half_max_effective_concentration[4]-
half_max_effective_concentration[5]-
half_max_effective_concentration[6]-
intercept[0]-
lag_weight[0]-
lag_weight[1]-
lag_weight[2]-
lag_weight[3]-
lag_weight[4]-
lag_weight[5]-
lag_weight[6]
0.01
0.01
0.01
0.00
0.36
0.14
0.30
0.18
0.26
0.06
0.26
0.16
-0.01
0.01
-0.01
-0.02
0.58
0.09
0.55
0.50
0.04
0.02
0.04
0.01
0.08
0.02
0.08
0.05
-0.03
0.01
-0.03
-0.05
-0.03
0.01
-0.03
-0.05
0.95
0.07
0.94
0.84
1.16
0.05
1.16
1.07
1.08
0.92
0.87
0.02
1.17
1.07
0.89
0.00
1.07
0.90
0.89
0.01
0.71
0.92
0.27
0.00
1.77
1.28
1.38
0.25
0.30
0.19
0.31
0.01
0.76
0.05
0.76
0.68
0.24
0.12
0.23
0.03
0.52
0.23
0.50
0.15
0.87
0.15
0.92
0.67
0.81
0.09
0.82
0.69
0.93
0.12
0.98
0.72
0.46
0.09
0.46
0.31
6
0.63
34.20
1.00
0.14
23.75
1.06
-
-
sigma[0]
0.13
0.01
0.13
0.12
slope[0]
6.19
1.55
6.32
3.47
slope[1]
9.47
2.52
9.10
5.22
slope[2]
0.41
0.51
0.20
0.00
slope[3]
0.62
1.11
0.07
0.00
slope[4]
0.84
0.72
0.60
0.02
slope[5]
1.02
1.40
0.08
0.00
slope[6]
0.65
0.33
0.61
0.15
1.02
1.00
1.16
21.85
1.01
2.48
10.29
1.04
1.96
24.74
1.05
3.12
3.76
1.74
1.13
33.69
1.00
Number of divergences: 998
# [25]:
import warnings

# The first call into arviz's numba-jitted density helpers emits a wall of
# Numba(Deprecation)Warnings about falling back to object mode (see the recorded
# output). The fallback changes how the helper is compiled, not the plot, so the
# warnings are silenced for this call only.
with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    posterior_fig = plot.plot_media_channel_posteriors(media_mix_model=mmm)
posterior_fig
/usr/local/lib/python3.10/dist-packages/arviz/utils.py:184:
NumbaDeprecationWarning: The 'nopython' keyword argument was not supplied to the
'numba.jit' decorator. The implicit default value for this argument is currently
False, but it will be changed to True in Numba 0.59.0. See
https://numba.readthedocs.io/en/stable/reference/deprecation.html#deprecationof-object-mode-fall-back-behaviour-when-using-jit for details.
numba_fn = numba.jit(**self.kwargs)(self.function)
/usr/local/lib/python3.10/dist-packages/arviz/stats/density_utils.py:957:
NumbaWarning:
Compilation is falling back to object mode WITH looplifting enabled because
Function "histogram" failed type inference due to: No implementation of function
Function() found for signature:
>>> histogram(readonly buffer(float32, 1d, C), bins=int64,
range=UniTuple(readonly buffer(float32, 0d, C) x 2))
There are 2 candidate implementations:
- Of which 2 did not match due to:
Overload in function 'np_histogram': File: numba/np/arraymath.py: Line 3912.
With argument(s): '(readonly buffer(float32, 1d, C), bins=int64,
range=UniTuple(readonly buffer(float32, 0d, C) x 2))':
Rejected as the implementation raised a specific error:
TypingError: Failed in nopython mode pipeline (step: nopython frontend)
No implementation of function Function()
found for signature:
7
>>> linspace(readonly buffer(float32, 0d, C), readonly buffer(float32, 0d,
C), int64)
There are 2 candidate implementations:
- Of which 2 did not match due to:
Overload of function 'linspace': File: numba/np/arrayobj.py: Line 4780.
With argument(s): '(readonly buffer(float32, 0d, C), readonly
buffer(float32, 0d, C), int64)':
No match.
During: resolving callee type: Function()
During: typing of call at /usr/local/lib/python3.10/distpackages/numba/np/arraymath.py (3953)
File "../usr/local/lib/python3.10/dist-packages/numba/np/arraymath.py", line
3953:
def histogram_impl(a, bins=10, range=None):
bins_array = np.linspace(bin_min, bin_max, bins + 1)
^
raised from /usr/local/lib/python3.10/distpackages/numba/core/typeinfer.py:1086
During: resolving callee type: Function()
During: typing of call at /usr/local/lib/python3.10/distpackages/arviz/stats/density_utils.py (979)
File "../usr/local/lib/python3.10/dist-packages/arviz/stats/density_utils.py",
line 979:
def histogram(data, bins, range_hist=None):
"""
hist, bin_edges = np.histogram(data, bins=bins, range=range_hist)
^
@conditional_jit(cache=True)
/usr/local/lib/python3.10/dist-packages/numba/core/object_mode_passes.py:151:
NumbaWarning: Function "histogram" was compiled in object mode without
forceobj=True.
File "../usr/local/lib/python3.10/dist-packages/arviz/stats/density_utils.py",
line 958:
@conditional_jit(cache=True)
8
def histogram(data, bins, range_hist=None):
^
warnings.warn(errors.NumbaWarning(warn_msg,
/usr/local/lib/python3.10/dist-packages/numba/core/object_mode_passes.py:161:
NumbaDeprecationWarning:
Fall-back from the nopython compilation path to the object mode compilation path
has been detected. This is deprecated behaviour that will be removed in Numba
0.59.0.
For more information visit
https://numba.readthedocs.io/en/stable/reference/deprecation.html#deprecationof-object-mode-fall-back-behaviour-when-using-jit
File "../usr/local/lib/python3.10/dist-packages/arviz/stats/density_utils.py",
line 958:
@conditional_jit(cache=True)
def histogram(data, bins, range_hist=None):
^
warnings.warn(errors.NumbaDeprecationWarning(msg,
[25]:
9
10
# [26]:
import warnings

# arviz triggers a NumbaDeprecationWarning here (see the recorded output). It is
# noise for this plot, so warnings are silenced for this call only.
with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    model_fit_fig = plot.plot_model_fit(mmm, target_scaler=target_scaler)
model_fit_fig
/usr/local/lib/python3.10/dist-packages/arviz/utils.py:184:
NumbaDeprecationWarning: The 'nopython' keyword argument was not supplied to the
'numba.jit' decorator. The implicit default value for this argument is currently
False, but it will be changed to True in Numba 0.59.0. See
https://numba.readthedocs.io/en/stable/reference/deprecation.html#deprecationof-object-mode-fall-back-behaviour-when-using-jit for details.
numba_fn = numba.jit(**self.kwargs)(self.function)
[26]:
# [27]:
# mmm was fitted on scaled media (cell [16]), so the test media must go through
# the same fitted scaler.
scaled_test_media = media_scaler.transform(test_media)
# Seeded so the sampled predictions, and the out-of-sample plot, are reproducible.
new_predictions = mmm.predict(media=scaled_test_media, seed=1)
new_predictions.shape
[27]: (1000, 25)
# [28]:
# Both predictions and target are in the target scaler's units: predictions come
# from a model fitted on scaled data, and the held-out target is scaled here.
plot.plot_out_of_sample_model_fit(
    out_of_sample_predictions=new_predictions,
    out_of_sample_target=target_scaler.transform(target_df[split_point:].squeeze()),
)
11
[28]:
[ ]: ## Media Insights
Media Insights
# [29]:
# Posterior media-contribution and ROI estimates per channel. The fitted target
# and cost scalers are passed so the metrics can be mapped back from scaled units.
media_contribution, roi_hat = mmm.get_posterior_metrics(
    target_scaler=target_scaler, cost_scaler=cost_scaler)
# [30]:
from matplotlib import pyplot as plt
import numpy as np


def custom_plot_media_baseline_contribution_area_plot(
        media_mix_model,
        target_scaler=None,
        channel_names=None,
        fig_size=(20, 7)):
    """Plots an area chart to visualize weekly media & baseline contribution.

    Negative contributions are clipped to zero before stacking, so a channel (or
    the baseline) with a negative contribution in some period is drawn as 0 there.

    Args:
        media_mix_model: Media mix model.
        target_scaler: Scaler used for scaling the target.
        channel_names: Names of media channels.
        fig_size: Size of the figure to plot as used by matplotlib.

    Returns:
        Stacked area chart of weekly baseline & media contribution, as a closed
        matplotlib Figure. It is displayed when it is a cell's last expression.
    """
    # Create media channels & baseline contribution dataframe.
    contribution_df = plot.create_media_baseline_contribution_df(
        media_mix_model=media_mix_model,
        target_scaler=target_scaler,
        channel_names=channel_names)
    # pandas' stacked area plot rejects columns that mix positive and negative
    # values, hence the floor at zero.
    contribution_df = contribution_df.clip(lower=0)

    # Keep only the "... contribution" columns, in reverse order, and add a 1-based
    # period column for the x axis. assign() builds a new frame, which avoids
    # writing a column into a slice of contribution_df (SettingWithCopyWarning).
    contribution_columns = [
        col for col in contribution_df.columns if "contribution" in col
    ]
    contribution_df_for_plot = contribution_df.loc[:, contribution_columns[::-1]]
    contribution_df_for_plot = contribution_df_for_plot.assign(
        period=np.arange(1, contribution_df_for_plot.shape[0] + 1))

    # Plot the stacked area chart.
    fig, ax = plt.subplots(figsize=fig_size)
    contribution_df_for_plot.plot.area(x="period", stacked=True, ax=ax)
    ax.set_title("Attribution Over Time", fontsize="x-large")
    ax.set_ylabel("Baseline & Media Channels Attribution")
    ax.set_xlabel("Period")
    ax.set_xlim(1, contribution_df_for_plot["period"].max())
    ax.set_xticks(contribution_df_for_plot["period"])
    ax.set_xticklabels(contribution_df_for_plot["period"], rotation=45)
    # Close this specific figure so it renders once, from the returned object.
    plt.close(fig)
    return fig
# [31]:
# Weekly attribution area chart on the original target scale, drawn extra wide.
attribution_fig = custom_plot_media_baseline_contribution_area_plot(
    media_mix_model=mmm,
    fig_size=(30, 10),
    target_scaler=target_scaler,
)
attribution_fig
[31]:
13
# [32]:
# Per-channel share of the target attributed to media (posterior from cell [29]).
plot.plot_bars_media_metrics(metric=media_contribution,
                             metric_name="Media Contribution Percentage")
[32]:
# [33]:
# Per-channel ROI estimate (posterior from cell [29]).
roi_fig = plot.plot_bars_media_metrics(metric_name="ROI hat", metric=roi_hat)
roi_fig
[33]:
14
# [34]:
# Response (saturation) curve per channel, on the original target scale.
response_curves_fig = plot.plot_response_curves(
    target_scaler=target_scaler, media_mix_model=mmm)
response_curves_fig
[34]:
15
Optimization
# [36]:
# Unit price 1 per channel: the media columns are spend amounts (the *_spend
# columns), so one unit of media costs one unit of money.
prices = jnp.ones(mmm.n_media_channels)
# [38]:
n_time_periods = 10
# Budget = average weekly spend across all channels (full history) x horizon.
average_weekly_spend = jnp.dot(prices, media_df.mean(axis=0))
budget = jnp.sum(average_weekly_spend) * n_time_periods
# [39]:
# Run optimization with the parameters of choice.
# seed fixes the PRNG the optimiser's model predictions draw from, so the
# suggested allocation is reproducible.
solution, kpi_without_optim, previous_budget_allocation = optimize_media.find_optimal_budgets(
    n_time_periods=n_time_periods,
    media_mix_model=mmm,
    budget=budget,
    prices=prices,
    media_scaler=media_scaler,
    target_scaler=target_scaler,
    seed=1)
# `solution` is a scipy OptimizeResult; don't silently use a non-converged result.
if not solution.success:
    print(f"WARNING: budget optimisation did not converge: {solution.message}")
Optimization terminated successfully
(Exit mode 0)
Current function value: -
Iterations: 20
Function evaluations: 300
Gradient evaluations: 20
# [40]:
# Obtain the optimal weekly allocation: optimiser solution times unit prices.
# The variable name's spelling matches the `optimal_buget_allocation` keyword
# used in cell [41].
optimal_buget_allocation = jnp.multiply(prices, solution.x)
optimal_buget_allocation
[40]: Array([- ,-,
-,- ,-,
6944.68 ], dtype=float32)
-,
[ ]:
Pre-post optimization budget allocation comparison for each channel
Pre post optimization predicted target variable comparison
# [41]:
# Plot out pre post optimization budget allocation and predicted target variable comparison.
# Note: `optimal_buget_allocation` (sic) is the keyword name this plot function expects.
plot.plot_pre_post_budget_allocation_comparison(
    media_mix_model=mmm,
    kpi_with_optim=solution['fun'],
    kpi_without_optim=kpi_without_optim,
    optimal_buget_allocation=optimal_buget_allocation,
    previous_budget_allocation=previous_budget_allocation,
    figure_size=(10, 10))
[41]:
17
18