# Copyright 2024 The PyMC Labs Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Regression kink design
"""
import warnings # noqa: I001
from matplotlib import pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from patsy import build_design_matrices, dmatrices
from causalpy.plot_utils import plot_xY
from .base import BaseExperiment
from causalpy.utils import round_num
from causalpy.custom_exceptions import (
DataException,
FormulaException,
)
from causalpy.utils import _is_variable_dummy_coded
# Font size of the legend drawn by ``RegressionKink.bayesian_plot``.
LEGEND_FONT_SIZE = 12
class RegressionKink(BaseExperiment):
    """Regression Kink experiment class.

    Fits ``model`` to the design matrices produced by ``formula`` and
    estimates the change in slope of the predicted outcome (``mu``) at
    ``kink_point`` along the running variable, using finite differences of
    width ``epsilon`` on either side of the kink.

    :param data:
        A pandas dataframe containing the columns referenced by ``formula``.
    :param formula:
        A patsy formula specifying the model.
    :param kink_point:
        Value of the running variable at which the change in slope is
        evaluated.
    :param model:
        The model used for fitting and prediction; it must provide ``fit``,
        ``score`` and ``predict``.
    :param running_variable_name:
        Name of the running variable column in ``data``. Defaults to ``"x"``.
    :param epsilon:
        Step size of the finite differences taken either side of the kink
        point. Defaults to 0.001.
    :param bandwidth:
        When finite, only rows whose running variable lies within
        ``[kink_point - bandwidth, kink_point + bandwidth]`` are used for
        fitting. Defaults to ``np.inf`` (use all rows).
    """

    supports_ols = False
    supports_bayes = True

    def __init__(
        self,
        data: pd.DataFrame,
        formula: str,
        kink_point: float,
        model=None,
        running_variable_name: str = "x",
        epsilon: float = 0.001,
        bandwidth: float = np.inf,
        **kwargs,
    ):
        super().__init__(model=model)
        self.expt_type = "Regression Kink"
        self.data = data
        self.formula = formula
        self.running_variable_name = running_variable_name
        self.kink_point = kink_point
        self.epsilon = epsilon
        self.bandwidth = bandwidth
        self.input_validation()

        running_variable = self.data[self.running_variable_name]
        # Compare by value, not identity: ``float("inf")``, ``math.inf`` or a
        # numpy float64 infinity must also mean "no bandwidth filtering".
        if self.bandwidth != np.inf:
            fmin = self.kink_point - self.bandwidth
            fmax = self.kink_point + self.bandwidth
            # Filter on the configured running variable (inclusive bounds).
            in_window = (running_variable >= fmin) & (running_variable <= fmax)
            fit_data = self.data.loc[in_window]
            if len(fit_data) <= 10:
                warnings.warn(
                    f"Choice of bandwidth parameter has lead to only {len(fit_data)} remaining datapoints. Consider increasing the bandwidth parameter.",  # noqa: E501
                    UserWarning,
                    stacklevel=2,
                )
        else:
            fit_data = self.data
            fmin = np.min(running_variable)
            fmax = np.max(running_variable)

        y, X = dmatrices(formula, fit_data)
        self._y_design_info = y.design_info
        self._x_design_info = X.design_info
        self.labels = X.design_info.column_names
        self.y, self.X = np.asarray(y), np.asarray(X)
        self.outcome_variable_name = y.design_info.column_names[0]

        coords = {"coeffs": self.labels, "obs_indx": np.arange(self.X.shape[0])}
        self.model.fit(X=self.X, y=self.y, coords=coords)
        # Score the goodness of fit on the rows used for fitting (the
        # bandwidth-filtered subset when a finite bandwidth is given).
        self.score = self.model.score(X=self.X, y=self.y)

        # Predict the outcome on an evenly spaced grid over the fitted range.
        xi = np.linspace(fmin, fmax, 200)
        self.x_pred = pd.DataFrame(
            {self.running_variable_name: xi, "treated": self._is_treated(xi)}
        )
        (new_x,) = build_design_matrices([self._x_design_info], self.x_pred)
        self.pred = self.model.predict(X=np.asarray(new_x))

        # Evaluate the gradient change around the kink point.
        mu_kink_left, mu_kink, mu_kink_right = self._probe_kink_point()
        self.gradient_change = self._eval_gradient_change(
            mu_kink_left, mu_kink, mu_kink_right, self.epsilon
        )

    @staticmethod
    def _eval_gradient_change(mu_kink_left, mu_kink, mu_kink_right, epsilon):
        """Evaluate the gradient change at the kink point.

        Uses one-sided finite differences: the slope to the left is
        ``(mu_kink - mu_kink_left) / epsilon`` and the slope to the right is
        ``(mu_kink_right - mu_kink) / epsilon``. This is a static method for
        ease of testing.

        :param mu_kink_left: Prediction at ``kink_point - epsilon``.
        :param mu_kink: Prediction at ``kink_point``.
        :param mu_kink_right: Prediction at ``kink_point + epsilon``.
        :param epsilon: Finite-difference step size.
        :returns: ``gradient_right - gradient_left``.
        """
        gradient_left = (mu_kink - mu_kink_left) / epsilon
        gradient_right = (mu_kink_right - mu_kink) / epsilon
        return gradient_right - gradient_left

    def _probe_kink_point(self):
        """Predict ``mu`` just below, at, and just above the kink point.

        :returns: Tuple ``(mu_kink_left, mu_kink, mu_kink_right)`` taken from
            ``predicted["posterior_predictive"]`` at ``obs_ind`` 0, 1 and 2.
        """
        x_predict = pd.DataFrame(
            {
                self.running_variable_name: np.array(
                    [
                        self.kink_point - self.epsilon,
                        self.kink_point,
                        self.kink_point + self.epsilon,
                    ]
                ),
                # Matches ``_is_treated``: the kink point itself is treated.
                "treated": np.array([0, 1, 1]),
            }
        )
        (new_x,) = build_design_matrices([self._x_design_info], x_predict)
        predicted = self.model.predict(X=np.asarray(new_x))
        posterior_predictive = predicted["posterior_predictive"]
        mu_kink_left = posterior_predictive.sel(obs_ind=0)["mu"]
        mu_kink = posterior_predictive.sel(obs_ind=1)["mu"]
        mu_kink_right = posterior_predictive.sel(obs_ind=2)["mu"]
        return mu_kink_left, mu_kink, mu_kink_right

    def _is_treated(self, x):
        """Returns ``True`` if `x` is greater than or equal to the treatment threshold."""  # noqa: E501
        return np.greater_equal(x, self.kink_point)

    def summary(self, round_to=None) -> None:
        """Print summary of main results and model coefficients.

        :param round_to:
            Forwarded to ``round_num`` (and ``print_coefficients``) to control
            how displayed numbers are rounded. Defaults to ``None``.
        """
        slope_change = round_num(self.gradient_change.mean(), round_to)
        lines = [
            f"{self.expt_type:=^80}",
            f"Formula: {self.formula}",
            f"Running variable: {self.running_variable_name}",
            f"Kink point on running variable: {self.kink_point}",
            "Results:",
            f"Change in slope at kink point = {slope_change}",
        ]
        print("\n".join(lines))
        self.print_coefficients(round_to)

    def bayesian_plot(self, round_to=None, **kwargs) -> tuple[plt.Figure, plt.Axes]:
        """Generate plot for regression kink designs.

        Draws the raw data, the posterior prediction of ``mu`` over the
        fitted range, and a vertical line at the kink point. The title reports
        the Bayesian R^2, the mean gradient change and its 94% interval.

        :param round_to:
            Forwarded to ``round_num`` to control rounding of the numbers in
            the title.
        :returns: The matplotlib ``(figure, axes)`` pair.
        """
        fig, ax = plt.subplots()
        # Plot raw data
        sns.scatterplot(
            self.data,
            x=self.running_variable_name,
            y=self.outcome_variable_name,
            c="k",
            ax=ax,
        )
        # Plot model fit to data
        h_line, h_patch = plot_xY(
            self.x_pred[self.running_variable_name],
            self.pred["posterior_predictive"].mu,
            ax=ax,
            plot_hdi_kwargs={"color": "C1"},
        )
        handles = [(h_line, h_patch)]
        labels = ["Posterior mean"]

        # Compose the title
        r2_mean = round_num(self.score.r2, round_to)
        r2_std = round_num(self.score.r2_std, round_to)
        r2 = f"Bayesian $R^2$ on all data = {r2_mean} (std = {r2_std})"
        percentiles = self.gradient_change.quantile([0.03, 1 - 0.03]).values
        ci = (
            r"$CI_{94\%}$"
            + f"[{round_num(percentiles[0], round_to)}, {round_num(percentiles[1], round_to)}]"  # noqa: E501
        )
        grad_change = (
            f"Change in gradient = {round_num(self.gradient_change.mean(), round_to)}, "
        )
        ax.set(title=r2 + "\n" + grad_change + ci)

        # Intervention line
        ax.axvline(
            x=self.kink_point,
            ls="-",
            lw=3,
            color="r",
            label="treatment threshold",
        )
        ax.legend(
            handles=handles,
            labels=labels,
            fontsize=LEGEND_FONT_SIZE,
        )
        return fig, ax