Interpretable ML 3: What nobody tells you about SHAP values

Thomas Bury
11 min read · Feb 11, 2024


Photo by Suhash Villuri on Unsplash

Our journey through the intricacies of interpretable machine learning continues! In this episode, we’ll venture beyond simplistic interpretations and delve into the practicalities and pitfalls of interpreting models like GBMs with non-normal losses. Buckle up, it’s quite a journey until we reach our goal!

Disclaimer: It’s Not Quite Black and White

First things first, let’s address the elephant in the room: is interpretability a myth? Not entirely. But it’s certainly not a walk in the park. Linear models might seem transparent, but for complex relationships, they offer mere glimpses, not the full picture. And while GBMs, with their impressive predictive power, aren’t complete black boxes, interpreting them requires navigating complexities like collinearity and correlation traps.

Tools of the Trade: Python Packages and Beyond

Don’t worry, you’re not unarmed! Packages like ARFS, scicomap, and geomapviz are ready to aid your analysis. But remember, interpretation is just one step in the journey. True understanding comes from a comprehensive approach, often involving deeper analysis or even behavioral experiments to validate your findings.

Correlation vs. Causation: A Crucial Distinction

Remember, model interpretation tells you how the model makes predictions, not the why behind the data. It’s the difference between correlation and causation: a predictive model captures correlations, not causal relationships. So, relying solely on interpretation for causal decisions can be misleading. When causation matters, turn to dedicated causal inference methods.

SHAPing Understanding: A Powerful Tool, but Use It Wisely

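SHAP (SHapley Additive exPlanations) offers valuable insights into GBMs, but beware! Correlation can distort its interpretations: the model gets evaluated on unrealistic combinations of feature values, leading to false counterfactuals, and credit can be split arbitrarily among correlated features, inflating the variance of the attributions. If correlation rears its ugly head, group SHAP values to mitigate its impact. Also, for significant correlation structures, ditch What-Ifs and counterfactuals and opt for ALE (Accumulated Local Effects) plots instead.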
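As a rough illustration of why ALE copes better with correlated features: instead of setting a feature to arbitrary what-if values for every row, it only moves each row across the narrow quantile bin it already sits in, averages those local prediction differences, and accumulates them. The sketch below is a minimal first-order version for one numeric feature, assuming the model is exposed as a `predict` function on NumPy arrays; `ale_1d` is a hypothetical helper, not the API of any ALE package.

```python
import numpy as np


def ale_1d(predict, X, feature, n_bins=10):
    """First-order ALE for one numeric column of a 2-D NumPy array X (hypothetical helper)."""
    x = X[:, feature]
    # quantile-based bin edges, so each bin holds roughly the same number of rows
    edges = np.unique(np.quantile(x, np.linspace(0, 1, n_bins + 1)))
    # bin k covers (edges[k], edges[k + 1]]; the minimum goes into the first bin
    bins = np.clip(np.searchsorted(edges, x, side="left") - 1, 0, len(edges) - 2)

    effects = np.zeros(len(edges) - 1)
    counts = np.zeros(len(edges) - 1)
    for k in range(len(edges) - 1):
        rows = X[bins == k]
        if len(rows) == 0:
            continue
        lo, hi = rows.copy(), rows.copy()
        lo[:, feature] = edges[k]
        hi[:, feature] = edges[k + 1]
        # local effect: each row only moves across its own bin, so no unrealistic what-ifs
        effects[k] = np.mean(predict(hi) - predict(lo))
        counts[k] = len(rows)

    # accumulate the local effects; ale[k] is the curve's value at edges[k]
    ale = np.concatenate([[0.0], np.cumsum(effects)])
    # center the curve so its count-weighted average over the bins is zero
    mid = (ale[:-1] + ale[1:]) / 2
    ale -= np.sum(mid * counts) / np.sum(counts)
    return edges, ale


# Usage: x1 is almost a copy of x0, yet the ALE slope for x0 recovers its coefficient (3)
rng = np.random.default_rng(0)
x0 = rng.normal(size=500)
x1 = x0 + 0.1 * rng.normal(size=500)
X = np.column_stack([x0, x1])
edges, ale = ale_1d(lambda A: 3 * A[:, 0] + A[:, 1], X, feature=0)
print(np.diff(ale) / np.diff(edges))
```

For a real model, `predict` would wrap the fitted regressor's prediction method; categorical features need an ordering scheme of their own, which this sketch does not handle.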

Scale Matters: Navigating the Numerical Maze

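For models with a log link (such as Tweedie, Poisson, or gamma objectives), SHAP values live on the log scale of the raw model output. Don’t get lost in translation! Convert them back to the original scale if you need interpretations in natural units. Remember, though, that exponentiating turns the additive decomposition into a multiplicative one, so true additive interpretations might be elusive in such cases. We’ll see that in the example below.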

Deriving Similarity: A Complex Quest

Finding similar instances in complex datasets is no easy feat. Currently, there’s no magic spell for this one, especially when dealing with mixed data types and varying weights. Stay critical and question the purpose of your interpretation. Is it for regulation, causal inference, or plain curiosity? Different goals demand different tools, like causal models or A/B testing.

Data Dive: Insurance Claims and Predicting Pure Premium

Our case study delves into open insurance claims data. We’ll predict the “pure premium”: the total claim amount per unit of exposure, where exposure is the fraction of the year during which the policy was in force. If the actuarial jargon is unfamiliar, don’t worry; it does not matter here. The goal is to use a non-Gaussian distribution, such as the Tweedie distribution, which belongs to the exponential dispersion family.
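To make the “non-Gaussian loss” concrete, here is a sketch of the Tweedie unit deviance for a power parameter between 1 and 2, the compound Poisson-gamma case where Var(Y) = φ·μ^p. It assigns a finite loss to the many policies with zero claims while still modelling the positive amounts. The function name `tweedie_unit_deviance` is hypothetical; the formula is the standard unit deviance, and its mean should agree with scikit-learn's `mean_tweedie_deviance` for the same `power`, although this sketch does not depend on that library.

```python
import numpy as np


def tweedie_unit_deviance(y, mu, power=1.5):
    """Tweedie unit deviance d(y, mu) for 1 < power < 2 (compound Poisson-gamma case)."""
    y = np.asarray(y, dtype=float)
    mu = np.asarray(mu, dtype=float)
    p = power
    if not 1 < p < 2:
        raise ValueError("this sketch only covers 1 < power < 2")
    return 2 * (
        np.power(y, 2 - p) / ((1 - p) * (2 - p))
        - y * np.power(mu, 1 - p) / (1 - p)
        + np.power(mu, 2 - p) / (2 - p)
    )


# zero-claim policies (y = 0) get a finite loss that grows with the predicted premium mu
print(tweedie_unit_deviance([0.0, 0.0, 250.0], [20.0, 50.0, 50.0]))
# a perfect prediction has zero deviance (up to floating-point rounding)
print(tweedie_unit_deviance(250.0, 250.0))
```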

Remember, the key is to be critical, cautious, and use the right tools for the job. The path to understanding may be winding, but with the right mindset and approach, we can navigate the interpretability maze and unlock valuable insights.

Open data on insurance claims


from sklearn.pipeline import Pipeline
from arfs.preprocessing import OrdinalEncoderPandas

# French MTPL insurance claims data, available on OpenML:
# https://www.openml.org/search?type=data&sort=runs&id=41214&status=active
# load_french_mtpl is a data-loading helper that is not defined in this snippet
df = load_french_mtpl(as_frame=True)

targets = ['ClaimNb', 'ClaimAmount', 'PurePremium', 'Frequency', 'AvgClaimAmount']
exposure = ['Exposure']
id_cols = ['IDpol', 'Exposure', 'index']
predictors = ['Area', 'BonusMalus', 'Density', 'DrivAge', 'Region', 'VehAge', 'VehBrand', 'VehGas', 'VehPower']
# sort the predictors alphabetically so that the column order is deterministic
predictors = sorted(predictors)
# set columns to keep
keep_cols = id_cols + targets + exposure + predictors
# preprocessing: ordinal-encode the categorical columns,
# leaving the ids and the split/fold indicators untouched
ord_enc = OrdinalEncoderPandas(exclude_cols=id_cols + ['split_idx', 'fold_nbr'], return_pandas_categorical=False)
preprocessing_pipe = Pipeline(steps=[("encoder", ord_enc)])
df_trans = preprocessing_pipe.fit_transform(df)


# pure premium: total claim amount per unit of exposure (policy-year)
df_trans["PurePremium"] = df_trans["ClaimAmount"] / df_trans["Exposure"]
# models are compared on the test set only; split_idx is expected to come with the loaded data
mask_train = df_trans.split_idx != 'test'
train = df_trans[mask_train].copy()
test = df_trans[~mask_train].copy()
Image by author
import arfs.association

# pairwise associations between predictors of mixed types:
# Theil's U (nominal-nominal), Pearson (numeric-numeric), correlation ratio (nominal-numeric)
assoc_mat = arfs.association.association_matrix(
    train[predictors],
    sample_weight=None,
    nom_nom_assoc='theil',
    num_num_assoc='pearson',
    nom_num_assoc='correlation_ratio',
    n_jobs=-1,
    handle_na='drop',
    nom_nom_comb=None,
    num_num_comb=None,
    nom_num_comb=None,
)
# reshape the pairwise output into a square matrix and plot it
assoc_mat = arfs.association.xy_to_matrix(assoc_mat)
f, ax = arfs.association.plot_association_matrix(
    assoc_mat,
    suffix_dic=None,
    ax=None,
    cmap='PuOr',
    cbarlabel=None,
    figsize=None,
    show=False,
    cbar_kw=None,
    imgshow_kw=None,
    annotate=False,
)
Image by author

Non-Gaussian Model

Now, it’s time to unleash the models! But hold on, let’s introduce a new twist: the Tweedie distribution, different from the usual Gaussian distribution.

For this illustration, we’ll train a LightGBM regressor, known for its efficiency and flexibility. Buckle up, data scientists, as we dive into the code with these basic parameters:

import lightgbm

# Model training with a Tweedie loss; LightGBM fits the raw score on the log scale (log link)
model = lightgbm.LGBMRegressor(
    objective="tweedie",
    tweedie_variance_power=1.5,  # LightGBM requires a value in [1, 2); adjust based on your data exploration
    learning_rate=0.1,
    n_estimators=100,
    random_state=42,
)
model.fit(train[predictors], train["PurePremium"])

Are the results what we expected? Do they align with our intuition about insurance claims? More importantly, do they offer valuable insights for risk assessment and technical pricing strategies?

Unveiling the Hidden: Introducing Interpreters

Hold onto your data detective hats, as we embark on a journey into the intriguing world of interpreters, and especially into what nobody tells you about SHAP for non-Gaussian losses! These powerful tools will crack open the black box of our model, revealing how individual features and their intricate relationships influence the final prediction. Get ready to witness interpretability magic as we illuminate the hidden workings of machine learning!

SHAP (aggregated)

Image by author

SHAP values provide an “additive attribution” for our model: the base value and the per-feature contributions add up to the model output. But hold on, there’s a twist! For LightGBM (and other GBMs trained with a log-link objective such as Tweedie), the explainer works on the raw model output, so these attributions live on a logarithmic scale. Think of it as a secret code waiting to be cracked. In our example, the SHAP values are not expressed in EUR but in log(EUR) units!

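Fortunately, we have the tools to decipher this code. By exponentiating the base value and each Shapley value, we can transform these log-scale Shapley values into “classical relativities” like those found in GLMs: the prediction equals exp(base value) multiplied by the product of the per-feature factors exp(φ_j). This allows us to interpret them as multiplicative attributions, where a positive Shapley value scales the prediction up by a factor greater than one and a negative value scales it down. Note that the base value is an average on the log scale, so exp(base value) is generally lower than the average predicted premium in euros.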
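The identity behind this conversion can be checked on made-up numbers. The values below are purely illustrative, not taken from the model: a log-scale base value φ_0 and three log-scale Shapley values φ_j. Summing on the log scale and exponentiating gives the same prediction as multiplying exp(φ_0) by the per-feature relativities exp(φ_j).

```python
import numpy as np

# Illustrative (made-up) SHAP output for one policy, on the log scale (log(EUR))
base_value = np.log(80.0)            # phi_0: the explainer's base value
phi = np.array([0.30, -0.10, 0.05])  # phi_1..phi_3: one value per feature

# additive on the log scale: log(mu) = phi_0 + phi_1 + phi_2 + phi_3
pred_eur = np.exp(base_value + phi.sum())

# multiplicative on the natural scale: mu = exp(phi_0) * exp(phi_1) * exp(phi_2) * exp(phi_3)
relativities = np.exp(phi)           # GLM-style relativities: > 1 pushes up, < 1 pulls down
pred_from_relativities = np.exp(base_value) * relativities.prod()

print(pred_eur, pred_from_relativities, relativities)
```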

Image by author
import copy

import numpy as np
import shap

# TreeSHAP explainer on the raw (log-scale) LightGBM output
explainer = shap.Explainer(model)
shap_values = explainer(test[predictors])

shap_exp = copy.deepcopy(shap_values)
# converting the base value to natural units (EUR)
shap_exp.base_values = np.exp(shap_exp.base_values)
# with the log link g: g(mu) = phi_0 + phi_1 + ... + phi_p, where phi_0 is the base value,
# so the SHAP values are in log units; taking the exp turns the sum into a product:
# exp(sum_j phi_j) = prod_j exp(phi_j)
# column j then holds the running prediction after multiplying in features 1..j,
# and the last column equals the model prediction in EUR
shap_exp.values = np.cumprod(np.exp(shap_exp.values), axis=1) * shap_exp.base_values[0]

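This method transforms the Shapley values back to the original units (euros in our case), but it results in a “multiplicative attribution” rather than the desired “additive” one: each column now holds the running prediction after multiplying in the features up to that column, not a per-feature contribution in euros. While informative, it’s important to remember this key difference.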

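Now, for the grand finale! We can convert these values into an additive attribution in euros using differences: the first feature gets its cumulative prediction (the base value multiplied by its own factor) minus the base value, and every other feature gets the change in the cumulative prediction when its factor is multiplied in. However, this is a bit of a magic trick and can be misleading: because the model is multiplicative on the natural scale, the euro amount credited to a feature depends on the order in which the features are multiplied in (here, the column order). Interpreting these values can be tricky, so tread carefully!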

shap_exp_centered = copy.deepcopy(shap_values)
# converting the base value to natural units (EUR)
shap_exp_centered.base_values = np.exp(shap_exp_centered.base_values)
base = shap_exp_centered.base_values[0]
# running predictions in EUR, as above: column j is the prediction after multiplying in features 1..j
running = np.cumprod(np.exp(shap_exp_centered.values), axis=1) * base
# additive, order-dependent split in EUR:
# feature 1 gets (running_1 - base), feature j gets (running_j - running_{j-1})
first_contrib_minus_base = np.expand_dims(running[:, 0] - base, axis=1)
target_diff = np.diff(running, axis=1)
# put the first feature first so the columns stay aligned with feature_names;
# each row now sums to (prediction - base) in EUR
shap_exp_centered.values = np.concatenate((first_contrib_minus_base, target_diff), axis=1)
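The order dependence of this additive split is easy to demonstrate on toy numbers. The sketch below re-implements the same cumulative-product-and-difference idea as a standalone function (`euro_attributions` is a hypothetical name) and applies it to one made-up row twice, once with the features in their original order and once reversed. Both splits add up to the same total, the prediction minus exp(base value), but the euros credited to each feature change.

```python
import numpy as np


def euro_attributions(phi, base_value):
    """Order-dependent additive split, in natural units, of log-scale SHAP values.

    phi: (n_rows, n_features) SHAP values on the log scale
    base_value: scalar base value on the log scale
    """
    base = np.exp(base_value)
    # running prediction after multiplying in features 1..j
    running = base * np.cumprod(np.exp(phi), axis=1)
    first = running[:, :1] - base       # first feature versus the base value
    rest = np.diff(running, axis=1)     # feature j versus features 1..j-1
    return np.concatenate([first, rest], axis=1)


phi = np.array([[0.4, 0.3, -0.2]])      # made-up log-scale SHAP values for one row
b = np.log(100.0)                       # made-up base value, i.e. 100 EUR

a = euro_attributions(phi, b)
# same features, processed in reverse order, then mapped back to the original columns
a_rev = euro_attributions(phi[:, ::-1], b)[:, ::-1]
print(a)
print(a_rev)
```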

With these decoding techniques in hand, we can unlock the true meaning of Shapley values and gain valuable insights into how each feature shapes our model’s predictions. Remember, interpretation isn’t always straightforward, but with the right tools and understanding, we can navigate these complexities and glean valuable knowledge from our models.

Feature Importance

Now, let’s shift gears and dive into the world of “technical charts” — think of them as cryptic maps revealing the inner workings of our model. Specifically, we’ll explore “feature importance” at the portfolio level, uncovering the overall impact of each feature on the predicted premium.

But wait, a word of caution! These charts speak a specialized language, not readily intelligible to everyone. They show absolute and directional importance (pushing the premium up or down), but not in euros, so don’t expect straight-up interpretations. And while insightful, boxplots aren’t always the easiest read for everyone.

Direction of Influence: This chart marks each feature’s median contribution with an arrow-shaped marker showing whether the feature typically “pushes” the premium up or “pulls” it down. Think of it as a tug-of-war, where each feature influences the outcome.

Absolute Impact: This chart uses boxplots of the absolute contributions to highlight the overall strength of each feature’s influence, regardless of direction. Imagine it as a spectrum, with features closer to zero having less impact than those further away.

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

s_df = pd.DataFrame(shap_exp_centered.values, columns=shap_exp_centered.feature_names)
# order the features by their median (signed) contribution
s_df_med = s_df.median().sort_values(ascending=False)
col_order = s_df_med.index.tolist()
# define the colors and markers from the sign of the median contribution
c = ['#db0421' if x >= 0 else '#0452b8' for x in s_df_med.values]
m = ['push up' if x >= 0 else 'pull down' for x in s_df_med.values]

# create the figure
fig = plt.figure(constrained_layout=True, figsize=(9, 4))
ax_dict = fig.subplot_mosaic([["norm", "abs"]])

# left panel: does a rating factor push the premium up or down, on average (median)?
ax_dict["norm"] = sns.boxplot(data=s_df, orient="h", showfliers=False, order=list(col_order), linewidth=0, medianprops=dict(color="black", linewidth=0), ax=ax_dict["norm"], color="#d9dadb", zorder=-1)
ax_dict["norm"] = sns.scatterplot(x=s_df_med.values, y=s_df_med.index, c=c, style=m, markers={'push up': '>', 'pull down': '<'}, ax=ax_dict["norm"], s=25)

# right panel: the absolute impact of the rating factors on the premium
s_df = pd.DataFrame(np.abs(shap_exp_centered.values), columns=shap_exp_centered.feature_names)
col_order = s_df.median().sort_values(ascending=False).index.tolist()
ax_dict["abs"] = sns.boxplot(data=s_df, orient="h", showfliers=False, order=list(col_order), linewidth=0, medianprops=dict(color="#e602de", linewidth=3), ax=ax_dict["abs"], color="#d9dadb")
plt.show()
Image by author

Remember, these charts offer valuable insights for technical folks who understand their nuances. But don’t worry, we’ll delve deeper into more intuitive interpretations below!

Taming Collinearity: Grouping for Clarity

Remember those numerous rating factors we explored earlier? While valuable, their sheer number can be overwhelming. It’s like searching for a specific star in a crowded galaxy! But fear not, intrepid data explorers! We have a powerful tool: collinearity mitigation.

This fancy term boils down to grouping related features, creating meta-characteristics that offer a clearer picture. Think of it like forming teams instead of analyzing individual players — it reveals broader patterns and trends.

In our case, we’ve grouped these features into three key categories:

  • Vehicles: Features like age, brand, fuel type, and power paint a portrait of the car itself.
  • Driver: Bonus-malus coefficient and age reveal the driver’s risk profile.
  • Geospatial: Density, region, and area provide insights into the driving environment.
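Why does grouping help? When two correlated features share the same signal, SHAP may credit one with a large positive value and the other with an offsetting negative one. Summing the signed, additive attributions within a group before taking absolute values lets those offsetting credits cancel, while each row still adds up to the same total. The toy attributions below are made up for illustration, and the two-group mapping is a reduced, hypothetical version of the one used in this section.

```python
import numpy as np
import pandas as pd

# made-up additive attributions (EUR) for three policies; Density and Area are
# assumed to be correlated and to split credit in opposite directions
attr = pd.DataFrame({"Density": [30.0, -25.0, 40.0],
                     "Area": [-28.0, 27.0, -35.0],
                     "BonusMalus": [10.0, 5.0, -3.0]})
groups = {"spatial": ["Density", "Area"], "driver": ["BonusMalus"]}

# sum the signed attributions within each group: each row keeps the same total
grouped = pd.DataFrame({g: attr[cols].sum(axis=1) for g, cols in groups.items()})

# importance taken after grouping: offsetting credits cancel out
grouped_importance = grouped.abs().median()
# importance from summing per-feature absolute values: the cancellation is lost
naive_spatial = attr[groups["spatial"]].abs().sum(axis=1).median()
print(grouped_importance)
print(naive_spatial)
```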
# Group Shapley values
import pandas as pd

s_df = pd.DataFrame(shap_exp_centered.values, columns=shap_exp_centered.feature_names)
groups = {"spatial": ["Density", "Region", "Area"],
          "driver": ["BonusMalus", "DrivAge"],
          "vehicle": ["VehAge", "VehBrand", "VehGas", "VehPower"]}
# sum the additive attributions within each group
cols_grouped = set()
for cluster_name, cols in groups.items():
    cols_grouped = cols_grouped.union(cols)
    s_df[cluster_name] = s_df[cols].values.sum(axis=1)
# drop the individual features, now represented by their group
if cols_grouped:
    s_df = s_df.drop(columns=sorted(cols_grouped))
# compute statistics and sort the groups by importance
s_df_med = s_df.median().sort_values(ascending=False)
# color by the sign of the median contribution (push up vs. pull down)
c = ['#9021ff' if v >= 0 else '#d9ad00' for v in s_df_med.values]

fig = plt.figure(constrained_layout=True, figsize=(8, 4))
ax_dict = fig.subplot_mosaic([["norm", "abs"]])
ax_dict["norm"].barh(s_df_med.index, s_df_med.values, color=c, zorder=0)
ax_dict["norm"].set_yticks(s_df_med.index)
ax_dict["norm"].set_yticklabels(s_df_med.index)
ax_dict["norm"].bar_label(ax_dict["norm"].containers[0], padding=15, fmt='%.0f', color='#999999', size=11, weight='bold')
ax_dict["norm"].autoscale(enable=True, axis='x', tight=False)
ax_dict["norm"].set_xmargin(0.15)
ax_dict["norm"].set_title('Average directional impact on the PP')
col_order = s_df_med.index.tolist()
ax_dict["norm"] = sns.boxplot(data=s_df, orient="h", showfliers=False, order=col_order,
width=0.25, linewidth=0, medianprops=dict(color="white", linewidth=2),
ax=ax_dict["norm"], color='#383838')
s_df = pd.DataFrame(np.abs(s_df.values), columns=s_df.columns)
s_df_abs = s_df.median().sort_values(ascending=True)
col_order = s_df_abs.index.tolist()
ax_dict["abs"].barh(s_df_abs.index, s_df_abs.values, zorder=0)
ax_dict["abs"].set_yticks(s_df_abs.index)
ax_dict["abs"].set_yticklabels(s_df_abs.index)
ax_dict["abs"].bar_label(ax_dict["abs"].containers[0], padding=3, fmt='%.0f', color='#999999', size=11, weight='bold')
ax_dict["abs"].autoscale(enable=True, axis='x', tight=False)
ax_dict["abs"].set_xmargin(0.15)
ax_dict["abs"].set_title('Average absolute impact on the PP')
ax_dict["abs"] = sns.boxplot(data=s_df, orient="h", showfliers=False, order=col_order,
width=0.25, linewidth=0, medianprops=dict(color="white", linewidth=2),
ax=ax_dict["abs"], color='#383838')
plt.show();
Image by author

Just like before, we’ll use two charts to analyze these groups:

1. Directional Influence: This time, the bars show the median impact of each group on the predicted premium, colored to indicate whether it typically pushes the price up or down, while the overlaid boxplots show how that impact varies across policies.

2. Absolute Impact: Here, the bars show the median absolute impact of each group, regardless of direction. Imagine it as a tug-of-war, where longer bars depict groups with stronger pulling power, and the boxplots show how much that pull varies from one policy to another.

Remember, these group-level insights are more readily interpretable than those for individual features. They offer a zoomed-out perspective, helping us grasp the bigger picture of how various aspects contribute to the final premium prediction.

Now, don’t get lost in the code! Think of it as a secret recipe enabling us to create these informative charts. The important takeaway is that by grouping features and analyzing them collectively, we gain a clearer understanding of how several factors interplay to influence our model’s predictions. So, the next time you encounter a multitude of features, remember — there’s power in grouping!

One step deeper in the rabbit hole of collinearity: Shapley vs. iBreakdown

Now, let’s compare two popular tools for explaining our model’s predictions: Shapley and iBreakdown. Both help us understand how individual features contribute to the outcome, but each has its own unique twist.

Shapley is an additive attribution (or can be converted into a multiplicative one), but it does not report interactions separately: an interaction effect is shared among the features involved.

iBreakdown, on the other hand, can (partially) capture interactions.

While iBreakdown offers intuitive interpretations and shines when facing model interactions, its calculations require more time (think quadratic complexity in the number of features). SHAP, on the other hand, is fast for tree ensembles thanks to its TreeSHAP algorithm, but may require additional effort to unpack its results.
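To see why the two approaches can disagree when features interact, here is a minimal sketch on a toy model with an interaction between features 0 and 1. It assumes an interventional value function (features outside the coalition are averaged over a background sample), which is a simplification: shap's tree explainer and dalex each define their own value functions, and `break_down` here is a plain sequential break-down, not iBreakDown's interaction-detection step. The helpers `value`, `break_down` and `shapley` are hypothetical. For a fixed value function, exact Shapley values are the break-down attributions averaged over every feature ordering.

```python
import itertools

import numpy as np


def value(f, x, background, subset):
    """Interventional coalition value: fix the features in `subset` to x, average f over the background."""
    data = background.copy()
    idx = list(subset)
    if idx:
        data[:, idx] = x[idx]
    return f(data).mean()


def break_down(f, x, background, order):
    """Sequential (break-down style) attributions for one feature ordering."""
    contrib = np.zeros(len(x))
    fixed = []
    prev = value(f, x, background, fixed)
    for j in order:
        fixed.append(j)
        cur = value(f, x, background, fixed)
        contrib[j] = cur - prev
        prev = cur
    return contrib


def shapley(f, x, background):
    """Exact Shapley values: break-down attributions averaged over all orderings."""
    orders = itertools.permutations(range(len(x)))
    return np.mean([break_down(f, x, background, o) for o in orders], axis=0)


def f(X):
    # toy model with an interaction between features 0 and 1
    return X[:, 0] * X[:, 1] + X[:, 2]


rng = np.random.default_rng(0)
background = rng.normal(size=(200, 3))
x = np.array([1.0, 2.0, 0.5])

bd_a = break_down(f, x, background, (0, 1, 2))
bd_b = break_down(f, x, background, (1, 0, 2))
phi = shapley(f, x, background)
print(bd_a, bd_b, phi)
```

Feature 2 enters the toy model additively and receives the same credit in every ordering, whereas the credit for the interacting features 0 and 1 depends on the order; Shapley averages that dependence away, while a single break-down path exposes it. With p features the exact average runs over p! orderings, so this brute force is only meant for intuition.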

The best choice depends on your specific needs and preferences. Do you prioritize speed or intuitiveness? Are interactions a major concern? Consider these factors and choose the tool that best empowers you to unlock the secrets of your model’s predictions!

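import dalex as dx
import matplotlib.pyplot as plt
import shap

# SHAP waterfall plot for the 11th test policy, using the additive EUR attributions
f, ax = plt.subplots()
shap.plots.waterfall(shap_exp_centered[10], show=False)
plt.show()

# iBreakDown with interactions for the same policy, on the prediction scale (EUR)
lgb_exp = dx.Explainer(model, test[predictors], test["PurePremium"], label="PP")
bdi = lgb_exp.predict_parts(test[predictors].iloc[10], type='break_down_interactions')
bdi.plot()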
SHAP on the left panel, iBreakDown on the right panel, including interaction between vehicle power and vehicle age. Image by author

Unveiling the Hidden: A Final Glimpse

Our journey into the challenging world of model interpretability has reached its end. We’ve explored various techniques, including converting Shapley values into natural units, effectively recovering GLM-style relativities and providing additional context for GBM users. We’ve also tackled part of the collinearity issues that can affect SHAP values, offering a more robust understanding of feature contributions.

Remember, interpretability is an ongoing quest, not a single destination. As models evolve and new challenges arise, so too will our tools and techniques. But with the knowledge we’ve gained, we are better equipped to navigate this complex landscape and extract valuable insights from our models.

So, keep exploring, keep questioning, and keep striving to understand the “why” behind the “what.” After all, it’s not just about making predictions, it’s about building trust and uncovering the hidden stories within our data. As Albert Einstein said, “The important thing is not to stop questioning. Curiosity has its own reason for existing.” Let’s embrace that curiosity and continue to unlock the true potential of interpretable models!


Thomas Bury

Physicist by passion and training, Data Scientist and MLE for a living (it's fun too), interdisciplinary by conviction.