# Copyright 2024 The PyMC Labs Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Pre/post intervention fit experiment designs
"""
from typing import List, Union
import arviz as az
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
from patsy import build_design_matrices, dmatrices
from sklearn.base import RegressorMixin
from causalpy.custom_exceptions import BadIndexException
from causalpy.plot_utils import plot_xY
from causalpy.pymc_models import PyMCModel
from causalpy.utils import round_num
from .base import BaseExperiment
LEGEND_FONT_SIZE = 12
[docs]
class PrePostFit(BaseExperiment):
"""
A base class for quasi-experimental designs where parameter estimation is based on
just pre-intervention data. This class is not directly invoked by the user.
"""
[docs]
def __init__(
self,
data: pd.DataFrame,
treatment_time: Union[int, float, pd.Timestamp],
formula: str,
model=None,
**kwargs,
) -> None:
super().__init__(model=model)
self.input_validation(data, treatment_time)
self.treatment_time = treatment_time
# set experiment type - usually done in subclasses
self.expt_type = "Pre-Post Fit"
# split data in to pre and post intervention
self.datapre = data[data.index < self.treatment_time]
self.datapost = data[data.index >= self.treatment_time]
self.formula = formula
# set things up with pre-intervention data
y, X = dmatrices(formula, self.datapre)
self.outcome_variable_name = y.design_info.column_names[0]
self._y_design_info = y.design_info
self._x_design_info = X.design_info
self.labels = X.design_info.column_names
self.pre_y, self.pre_X = np.asarray(y), np.asarray(X)
# process post-intervention data
(new_y, new_x) = build_design_matrices(
[self._y_design_info, self._x_design_info], self.datapost
)
self.post_X = np.asarray(new_x)
self.post_y = np.asarray(new_y)
# fit the model to the observed (pre-intervention) data
if isinstance(self.model, PyMCModel):
COORDS = {"coeffs": self.labels, "obs_indx": np.arange(self.pre_X.shape[0])}
self.model.fit(X=self.pre_X, y=self.pre_y, coords=COORDS)
elif isinstance(self.model, RegressorMixin):
self.model.fit(X=self.pre_X, y=self.pre_y)
else:
raise ValueError("Model type not recognized")
# score the goodness of fit to the pre-intervention data
self.score = self.model.score(X=self.pre_X, y=self.pre_y)
# get the model predictions of the observed (pre-intervention) data
self.pre_pred = self.model.predict(X=self.pre_X)
# calculate the counterfactual
self.post_pred = self.model.predict(X=self.post_X)
self.pre_impact = self.model.calculate_impact(self.pre_y[:, 0], self.pre_pred)
self.post_impact = self.model.calculate_impact(
self.post_y[:, 0], self.post_pred
)
self.post_impact_cumulative = self.model.calculate_cumulative_impact(
self.post_impact
)
[docs]
def input_validation(self, data, treatment_time):
"""Validate the input data and model formula for correctness"""
if isinstance(data.index, pd.DatetimeIndex) and not isinstance(
treatment_time, pd.Timestamp
):
raise BadIndexException(
"If data.index is DatetimeIndex, treatment_time must be pd.Timestamp."
)
if not isinstance(data.index, pd.DatetimeIndex) and isinstance(
treatment_time, pd.Timestamp
):
raise BadIndexException(
"If data.index is not DatetimeIndex, treatment_time must be pd.Timestamp." # noqa: E501
)
[docs]
def summary(self, round_to=None) -> None:
"""Print summary of main results and model coefficients.
:param round_to:
Number of decimals used to round results. Defaults to 2. Use "None" to return raw numbers
"""
print(f"{self.expt_type:=^80}")
print(f"Formula: {self.formula}")
self.print_coefficients(round_to)
[docs]
def bayesian_plot(
self, round_to=None, **kwargs
) -> tuple[plt.Figure, List[plt.Axes]]:
"""
Plot the results
:param round_to:
Number of decimals used to round results. Defaults to 2. Use "None" to return raw numbers.
"""
counterfactual_label = "Counterfactual"
fig, ax = plt.subplots(3, 1, sharex=True, figsize=(7, 8))
# TOP PLOT --------------------------------------------------
# pre-intervention period
h_line, h_patch = plot_xY(
self.datapre.index,
self.pre_pred["posterior_predictive"].mu,
ax=ax[0],
plot_hdi_kwargs={"color": "C0"},
)
handles = [(h_line, h_patch)]
labels = ["Pre-intervention period"]
(h,) = ax[0].plot(self.datapre.index, self.pre_y, "k.", label="Observations")
handles.append(h)
labels.append("Observations")
# post intervention period
h_line, h_patch = plot_xY(
self.datapost.index,
self.post_pred["posterior_predictive"].mu,
ax=ax[0],
plot_hdi_kwargs={"color": "C1"},
)
handles.append((h_line, h_patch))
labels.append(counterfactual_label)
ax[0].plot(self.datapost.index, self.post_y, "k.")
# Shaded causal effect
h = ax[0].fill_between(
self.datapost.index,
y1=az.extract(
self.post_pred, group="posterior_predictive", var_names="mu"
).mean("sample"),
y2=np.squeeze(self.post_y),
color="C0",
alpha=0.25,
)
handles.append(h)
labels.append("Causal impact")
ax[0].set(
title=f"""
Pre-intervention Bayesian $R^2$: {round_num(self.score.r2, round_to)}
(std = {round_num(self.score.r2_std, round_to)})
"""
)
# MIDDLE PLOT -----------------------------------------------
plot_xY(
self.datapre.index,
self.pre_impact,
ax=ax[1],
plot_hdi_kwargs={"color": "C0"},
)
plot_xY(
self.datapost.index,
self.post_impact,
ax=ax[1],
plot_hdi_kwargs={"color": "C1"},
)
ax[1].axhline(y=0, c="k")
ax[1].fill_between(
self.datapost.index,
y1=self.post_impact.mean(["chain", "draw"]),
color="C0",
alpha=0.25,
label="Causal impact",
)
ax[1].set(title="Causal Impact")
# BOTTOM PLOT -----------------------------------------------
ax[2].set(title="Cumulative Causal Impact")
plot_xY(
self.datapost.index,
self.post_impact_cumulative,
ax=ax[2],
plot_hdi_kwargs={"color": "C1"},
)
ax[2].axhline(y=0, c="k")
# Intervention line
for i in [0, 1, 2]:
ax[i].axvline(
x=self.treatment_time,
ls="-",
lw=3,
color="r",
)
ax[0].legend(
handles=(h_tuple for h_tuple in handles),
labels=labels,
fontsize=LEGEND_FONT_SIZE,
)
return fig, ax
[docs]
def ols_plot(self, round_to=None, **kwargs) -> tuple[plt.Figure, List[plt.Axes]]:
"""
Plot the results
:param round_to:
Number of decimals used to round results. Defaults to 2. Use "None" to return raw numbers.
"""
counterfactual_label = "Counterfactual"
fig, ax = plt.subplots(3, 1, sharex=True, figsize=(7, 8))
ax[0].plot(self.datapre.index, self.pre_y, "k.")
ax[0].plot(self.datapost.index, self.post_y, "k.")
ax[0].plot(self.datapre.index, self.pre_pred, c="k", label="model fit")
ax[0].plot(
self.datapost.index,
self.post_pred,
label=counterfactual_label,
ls=":",
c="k",
)
ax[0].set(
title=f"$R^2$ on pre-intervention data = {round_num(self.score, round_to)}"
)
ax[1].plot(self.datapre.index, self.pre_impact, "k.")
ax[1].plot(
self.datapost.index,
self.post_impact,
"k.",
label=counterfactual_label,
)
ax[1].axhline(y=0, c="k")
ax[1].set(title="Causal Impact")
ax[2].plot(self.datapost.index, self.post_impact_cumulative, c="k")
ax[2].axhline(y=0, c="k")
ax[2].set(title="Cumulative Causal Impact")
# Shaded causal effect
ax[0].fill_between(
self.datapost.index,
y1=np.squeeze(self.post_pred),
y2=np.squeeze(self.post_y),
color="C0",
alpha=0.25,
label="Causal impact",
)
ax[1].fill_between(
self.datapost.index,
y1=np.squeeze(self.post_impact),
color="C0",
alpha=0.25,
label="Causal impact",
)
# Intervention line
# TODO: make this work when treatment_time is a datetime
for i in [0, 1, 2]:
ax[i].axvline(
x=self.treatment_time,
ls="-",
lw=3,
color="r",
label="Treatment time",
)
ax[0].legend(fontsize=LEGEND_FONT_SIZE)
return (fig, ax)
[docs]
class InterruptedTimeSeries(PrePostFit):
"""
A wrapper around PrePostFit class
:param data:
A pandas dataframe
:param treatment_time:
The time when treatment occured, should be in reference to the data index
:param formula:
A statistical model formula
:param model:
A PyMC model
Example
--------
>>> import causalpy as cp
>>> df = (
... cp.load_data("its")
... .assign(date=lambda x: pd.to_datetime(x["date"]))
... .set_index("date")
... )
>>> treatment_time = pd.to_datetime("2017-01-01")
>>> seed = 42
>>> result = cp.InterruptedTimeSeries(
... df,
... treatment_time,
... formula="y ~ 1 + t + C(month)",
... model=cp.pymc_models.LinearRegression(
... sample_kwargs={
... "target_accept": 0.95,
... "random_seed": seed,
... "progressbar": False,
... }
... )
... )
"""
expt_type = "Interrupted Time Series"
supports_ols = True
supports_bayes = True
[docs]
class SyntheticControl(PrePostFit):
"""A wrapper around the PrePostFit class
:param data:
A pandas dataframe
:param treatment_time:
The time when treatment occured, should be in reference to the data index
:param formula:
A statistical model formula
:param model:
A PyMC model
Example
--------
>>> import causalpy as cp
>>> df = cp.load_data("sc")
>>> treatment_time = 70
>>> seed = 42
>>> result = cp.SyntheticControl(
... df,
... treatment_time,
... formula="actual ~ 0 + a + b + c + d + e + f + g",
... model=cp.pymc_models.WeightedSumFitter(
... sample_kwargs={
... "target_accept": 0.95,
... "random_seed": seed,
... "progressbar": False,
... }
... ),
... )
"""
expt_type = "SyntheticControl"
supports_ols = True
supports_bayes = True
[docs]
def bayesian_plot(self, *args, **kwargs) -> tuple[plt.Figure, List[plt.Axes]]:
"""
Plot the results
:param round_to:
Number of decimals used to round results. Defaults to 2. Use "None" to
return raw numbers.
:param plot_predictors:
Whether to plot the control units as well. Defaults to False.
"""
# call the super class method
fig, ax = super().bayesian_plot(*args, **kwargs)
# additional plotting functionality for the synthetic control experiment
plot_predictors = kwargs.get("plot_predictors", False)
if plot_predictors:
# plot control units as well
ax[0].plot(self.datapre.index, self.pre_X, "-", c=[0.8, 0.8, 0.8], zorder=1)
ax[0].plot(
self.datapost.index, self.post_X, "-", c=[0.8, 0.8, 0.8], zorder=1
)
return fig, ax