Explain This?! - Model Interpretability
Interpretability, also popularly known as human-interpretable interpretations (HII), of a machine learning model is the extent to which a human (including non-experts in machine learning) can understand the choices the model makes in its decision-making process.
A machine learning model by itself consists of an algorithm that tries to learn latent patterns and relationships from data without hard-coding fixed rules. This often makes it difficult to explain how a model works in a business context, and it has driven a recent surge of interest in interpreting models to explain how and why they make the predictions they do.
Interpretable ML - The Big Picture
Key Concepts
Intrinsic vs. Posthoc
Certain models, such as linear or tree-based models, have been designed in such a way that they are intrinsically interpretable by nature. For example, linear models are monotone and the dependent variable scales linearly with each predictor variable; this means the weights, or coefficients, that control this behaviour can be understood simply. In the same manner, simple tree-based methods split on features and thresholds that make human sense.
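As a minimal illustration of this (using a toy dataset generated on the spot, purely for this sketch, rather than the datasets used later in the post), the learned coefficients of a linear model and the learned split rules of a shallow decision tree can be read off directly:

# A small sketch of reading intrinsic interpretations from simple models.
from sklearn.datasets import make_regression
from sklearn.linear_model import LinearRegression
from sklearn.tree import DecisionTreeRegressor, export_text

X, y = make_regression(n_samples=200, n_features=3, noise=0.1, random_state=0)

# Linear model: each coefficient is the change in prediction per unit change in that feature
lin = LinearRegression().fit(X, y)
print(dict(zip(['f0', 'f1', 'f2'], lin.coef_.round(2))))

# Decision tree: the learned splits are human-readable feature/threshold rules
tree = DecisionTreeRegressor(max_depth=2).fit(X, y)
print(export_text(tree, feature_names=['f0', 'f1', 'f2']))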
By comparison, post-hoc interpretability refers to techniques that are applied after training to models that can be considered 'black box', where the internals are hard to decipher. An example of a post-hoc method is permutation feature importance, where a trained model is applied to a dataset in which the values of a single feature have been shuffled, and the importance of that feature is assessed by how much the predictions change from what was expected.
Model Specific vs. Model Agnostic
Interpretability techniques that are restricted to the type of model they are applied to are considered model specific; intrinsic techniques are therefore, by definition, model specific. Model-agnostic techniques are those that can be applied to any trained model, and hence are often post-hoc.
Local vs. Global
If we are able to explain an entire model's behaviour, i.e. describe how the model behaves across the whole feature space, then we call such a technique global. In contrast, if we can explain why a specific prediction was made, then the technique is considered to have local interpretability.
Techniques
In this post, I will describe some commonly used post-hoc, model-agnostic techniques. I will present high-level theory and give example applications using models trained on the Iris flowers and Boston House Prices datasets.
The following code snippet shows how I quickly and naively build a few regression and classification models.
from sklearn import datasets
import pandas as pd

# load data
iris = datasets.load_iris()
boston = datasets.load_boston()

boston_X = pd.DataFrame(boston.data, columns=boston.feature_names)
boston_y = pd.DataFrame(boston.target, columns=['prices'])

iris_X = pd.DataFrame(iris.data, columns=iris.feature_names)
iris_y = pd.DataFrame(iris.target, columns=['flower'])

from sklearn.linear_model import LinearRegression, LogisticRegression
from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor
from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor
from xgboost.sklearn import XGBClassifier, XGBRegressor
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, r2_score

def train_model(X, y, model, metric):
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.8, random_state=42)
    model.fit(X_train.values, y_train.values)
    y_test_hat = model.predict(X_test.values)
    score = metric(y_test, y_test_hat)
    print(f'Model {model.__class__.__name__} achieved a {score} score.')
    return model

lin_reg = train_model(boston_X, boston_y, LinearRegression(), r2_score)
dec_tree_reg = train_model(boston_X, boston_y, DecisionTreeRegressor(), r2_score)
ran_for_reg = train_model(boston_X, boston_y, RandomForestRegressor(), r2_score)
xgb_reg = train_model(boston_X, boston_y, XGBRegressor(), r2_score)

log_reg_clf = train_model(iris_X, iris_y, LogisticRegression(), accuracy_score)
dec_tree_clf = train_model(iris_X, iris_y, DecisionTreeClassifier(), accuracy_score)
ran_for_clf = train_model(iris_X, iris_y, RandomForestClassifier(), accuracy_score)
xgb_clf = train_model(iris_X, iris_y, XGBClassifier(), accuracy_score)
Feature Importance
Feature importance is a generic term for the degree to which a predictive model relies on a feature. Understanding which input features are important, and to what degree, can give incredible insight into the driving factors behind a predictive model.
The intuition behind most techniques is that you can determine how much a model relies on a particular feature by making slight changes to the values of that feature in the dataset and comparing the predictions with the baseline. You can say that there is high reliance on that feature if small changes in the feature have a large effect on the predictions. In contrast, a feature is not important to a model if you can change values in that feature with the predictions remaining mostly unchanged.
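As a rough sketch of this intuition (hand-rolled, reusing the linear regression and Boston data trained above and R² as the comparison metric; the libraries below implement more careful versions of this idea):

# A naive, hand-rolled permutation importance check:
# shuffle one feature at a time and measure how much the score drops.
import numpy as np
from sklearn.metrics import r2_score

rng = np.random.RandomState(42)
baseline = r2_score(boston_y, lin_reg.predict(boston_X.values))

for feature in boston_X.columns:
    X_shuffled = boston_X.copy()
    X_shuffled[feature] = rng.permutation(X_shuffled[feature].values)
    score = r2_score(boston_y, lin_reg.predict(X_shuffled.values))
    print(f'{feature}: drop in R^2 = {baseline - score:.3f}')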
ELI5
This package determines the importance of a feature by attempting to make a prediction without that feature present and comparing the new score with the original. It avoids retraining the model without the feature present by replacing all data within that feature with random noise. In this case, the random noise is generated by randomly shuffling all data within that feature column (hence the shuffled data keeps the same distribution), and then the predictions are compared.
Global
import eli5
from eli5.sklearn import PermutationImportance

perm = PermutationImportance(lin_reg).fit(boston_X, boston_y)
eli5.show_weights(perm, feature_names=boston_X.columns.values)

perm = PermutationImportance(log_reg_clf).fit(iris_X, iris_y)
eli5.show_weights(perm, feature_names=iris_X.columns.values)
ELI5 feature importance on Boston House Price dataset
ELI5 feature importance on Iris dataset
The output on the top shows the feature importance of the LinearRegression model when applied to the Boston House Price dataset. It shows that the RM and LSTAT features are the most important. The output on the bottom then shows that petal length and petal width are the most important predictors for the LogisticRegression classifier. For more information on how to interpret the weight figures, see the ELI5 documentation.
SHAP
This package applies the Shapley value concept from game theory to machine learning. The Shapley value of a feature (player) is its average marginal contribution to the prediction (payout) across all possible coalitions of features. The marginal contribution is the difference a feature makes relative to the average prediction. Edward Ma and Christoph Molnar explain this idea with some fantastic diagrams in a very approachable manner.
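To make the idea concrete, here is a rough, brute-force sketch of the Shapley computation for a single prediction, where a coalition's 'payout' is approximated by predicting with the features outside the coalition set to their dataset means (a common simplification for illustration only; the SHAP library uses far more efficient, model-specific algorithms):

# Brute-force Shapley values for one row, for illustration only.
import numpy as np
from itertools import combinations
from math import factorial

def shapley_for_row(model, X, row_idx):
    x = X.iloc[row_idx].values
    means = X.mean().values
    n = X.shape[1]

    def v(coalition):
        # features outside the coalition are replaced by their dataset mean
        z = means.copy()
        idx = list(coalition)
        z[idx] = x[idx]
        return float(np.ravel(model.predict(z.reshape(1, -1)))[0])

    phi = np.zeros(n)
    for i in range(n):
        others = [j for j in range(n) if j != i]
        for size in range(n):
            for S in combinations(others, size):
                weight = factorial(size) * factorial(n - size - 1) / factorial(n)
                phi[i] += weight * (v(S + (i,)) - v(S))
    return phi

# e.g. phi = shapley_for_row(lin_reg, boston_X, 0)  # slow: 2**12 coalitions per feature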
Global
To visualise feature importance using Shapley values as your metric, you initialise an 'Explainer' class with the model you are inspecting, along with the input data you trained the model with. Then you calculate the Shapley values and feed this array into one of the various plotting methods available from the SHAP library.
In the plot below, we show the average absolute value of all of the Shapley values by input feature; this output can be interpreted to say that the RM, DIS, and LSTAT features are the most impactful in influencing predictions across the whole set of input data.
import shap

explainer = shap.LinearExplainer(lin_reg, boston_X)
shap_values = explainer.shap_values(boston_X)

shap.summary_plot(shap_values, boston_X, plot_type='bar')
Global SHAP feature importance on Boston House Price dataset
To look deeper into feature importance, you can visualise the distribution of individual Shapley values by input feature. This can be useful in determining if there are subsets of the input data where there are skewed values from different features which could be significantly affecting prediction values.
For example, in the following plot you can see that the CRIM feature will have a large negative effect on the predicted house prices when the feature value is low; but when the feature value is high, it will have no effect on the predicted values.
shap.summary_plot(shap_values, boston_X)
Global SHAP feature importance on Boston House Price dataset
Local
explainer = shap.LinearExplainer(lin_reg, boston_X)
shap_values = explainer.shap_values(boston_X)

# visualize the first prediction's explanation (use matplotlib=True to avoid Javascript)
shap.force_plot(explainer.expected_value, shap_values[0, :], boston_X.iloc[0, :])
Local SHAP feature importance on Boston House Price dataset
You can interpret this output by knowing that the base value refers to the expected output of the model given the training set the model is built on. The arrows and associated features then show which features of this particular sample have contributed to a shift away from the average output. In this example, the LSTAT and PTRATIO features have had the largest effect in predicting the value of 28.45.
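One useful sanity check (a small sketch, assuming the explainer and shap_values from the snippet above are still in scope) is that the base value plus the sum of a sample's Shapley values reconstructs the model's prediction for that sample:

# SHAP's local accuracy property: base value + sum of Shapley values == prediction
import numpy as np

base = float(np.ravel(explainer.expected_value)[0])
reconstructed = base + shap_values[0, :].sum()
prediction = float(np.ravel(lin_reg.predict(boston_X.iloc[[0]].values))[0])
print(f'base + shap sum = {reconstructed:.2f}, model prediction = {prediction:.2f}')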
Partial Dependence Plots
The partial dependence plot (PDP or PD plot for short) shows the marginal effect one or two features have on the predicted outcome of a machine learning model (J. H. Friedman, 2001). A partial dependence plot can show whether the relationship between the target and a feature is linear, monotonic or more complex. For example, when applied to a linear regression model, partial dependence plots always show a linear relationship. The idea of PD plots is similar to plotting the gradient, or the value of the partial derivative, with respect to one or two input features.
The reason you can only effectively plot dependence for a maximum of two input variables is an artefact of our 3-dimensional world and how visualisations like these are shown on a 2-dimensional screen; mathematically speaking, the analysis would work for any number of dimensions, but interpretation works best with at most 2 or 3.
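To make the mechanics concrete, here is a rough hand-rolled sketch of how a one-feature partial dependence curve can be computed (reusing the random forest regressor and Boston data from above; the libraries below do this for you, along with interpolation and confidence bands):

# A hand-rolled one-feature partial dependence:
# fix LSTAT to each grid value for every row, predict, and average.
import numpy as np

feature = 'LSTAT'
grid = np.linspace(boston_X[feature].min(), boston_X[feature].max(), 20)

pd_values = []
for value in grid:
    X_mod = boston_X.copy()
    X_mod[feature] = value  # overwrite the feature for all rows
    pd_values.append(ran_for_reg.predict(X_mod.values).mean())

print(list(zip(grid.round(1), np.round(pd_values, 1))))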
PDPBox
One Python package you can use to visualise these marginal effects is called PDPBox. You can plot a partial dependence graph with
from matplotlib import pyplot as plt
from pdpbox import pdp, get_dataset, info_plots

feature = 'LSTAT'

# Create the data that we will plot
pdp_goals = pdp.pdp_isolate(
    model=ran_for_reg,
    dataset=boston_X,
    model_features=boston_X.columns.values,
    feature=feature)

# plot it
pdp.pdp_plot(pdp_goals, feature)
plt.show()
and the output will look like the graph below. The x-axis shows the values of the feature you are plotting against, and the y-axis shows the difference in predicted values compared with the left-most x value. The light blue band shows the confidence of this analysis. From the plot below, you can say that an increasing LSTAT value (the % of the population that is deemed 'lower status') has a negative effect on house prices until a value of around 20, beyond which it has very little effect.
Partial dependence plot for the LSTAT feature
The graphs below show the effects of two more variables. RM, the number of rooms in a house, has a positive effect on house prices above a value of 6, and CHAS, a dummy variable, has no effect on house prices.
Partial dependence plot for the number of rooms feature
Partial dependence plot for the CHAS feature
You can also use PDPBox to visualise the effect of two variables on the predicted value with
from matplotlib import pyplot as plt
from pdpbox import pdp, get_dataset, info_plots

features = ['LSTAT', 'RM']

# Create the data that we will plot
pdp_goals = pdp.pdp_interact(
    model=ran_for_reg,
    dataset=boston_X,
    model_features=boston_X.columns.values,
    features=features)

# plot it
pdp.pdp_interact_plot(pdp_goals, features, plot_type='grid')
plt.show()
Partial dependence interaction plot
The lighter colour indicates a strong positive effect, and the darker colour indicates the opposite. This plot shows that a higher value of RM combined with a lower value of LSTAT has the greatest positive effect on house prices.
Skater
You can find other packages that perform similar analysis to determine partial dependence. Skater, developed by the team at DataScience.com (since acquired by Oracle), is another example with an equally simple API and some beautiful, clear visualisations. Here are the equivalent partial dependence plots using the implementation in Skater. The 1-feature plots are
Skater partial dependence plot for the LSTAT feature
Skater partial dependence plot for the number of rooms feature
These can be generated using code like
from skater.core.explanations import Interpretation
from skater.model import InMemoryModel
from matplotlib import pyplot as plt

f, ax = plt.subplots(1, 1, figsize=(26, 18))

interpreter = Interpretation(boston_X, feature_names=boston_X.columns.values)
pyint_model = InMemoryModel(ran_for_reg.predict, examples=boston_X)

features = ['LSTAT', 'RM']
interpreter.partial_dependence.plot_partial_dependence(
    features,
    pyint_model
)
and the 2-feature plot is
Dual partial dependence plot for the LSTAT and RM features
This can be generated by almost identical code through changing
interpreter.partial_dependence.plot_partial_dependence(
    features,
    pyint_model
)
to
interpreter.partial_dependence.plot_partial_dependence(
    [features],  # max length 2
    pyint_model
)
The interpretation of the features and of this model is the same through both packages.
Surrogate Models
The idea of surrogate models is to use an interpretable model that is trained to approximate the predictions of a black box model. If the approximation is close enough, you can draw conclusions about how the black box model makes decisions by using the intrinsic interpretability of the surrogate model.
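As a rough sketch of this general recipe (before using Skater's built-in support below), you could fit a shallow decision tree directly to the black box's predictions and check how faithfully it reproduces them, for example with the random forest regressor trained earlier:

# A minimal global surrogate: fit a shallow tree to the black box's outputs
# and measure fidelity (how well the surrogate mimics the black box, not the truth).
from sklearn.tree import DecisionTreeRegressor
from sklearn.metrics import r2_score

black_box_preds = ran_for_reg.predict(boston_X.values)

surrogate = DecisionTreeRegressor(max_depth=3, random_state=42)
surrogate.fit(boston_X, black_box_preds)

fidelity = r2_score(black_box_preds, surrogate.predict(boston_X))
print(f'Surrogate fidelity (R^2 against black box predictions): {fidelity:.3f}')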
Skater - Global
We can use the tree_surrogate method from the Skater package to train a decision tree on the outputs of the base model. Here I train a tree on the outputs of the XGBoost classifier built on the Iris flower dataset. With the following code, I can generate a very good decision tree approximator, achieving a 0.993 F1 score against the XGBoost outputs.
interpreter = Interpretation(iris_X, feature_names=iris_X.columns.values)
pyint_model = InMemoryModel(xgb_clf.predict, examples=iris_X, unique_values=iris_y['flower'].unique())

surrogate_explainer = interpreter.tree_surrogate(oracle=pyint_model, seed=42)
surrogate_explainer.fit(iris_X, iris_y, use_oracle=True, scorer_type='f1')
I visualise this tree with
from skater.util.dataops import show_in_notebook
from graphviz import Source
from IPython.display import SVG

graph = Source(surrogate_explainer.plot_global_decisions(file_name='test_tree_pre_iris.png').to_string())
svg_data = graph.pipe(format='svg')

with open('dtree_structure_iris.svg', 'wb') as f:
    f.write(svg_data)

SVG(svg_data)
Surrogate decision tree
You can use this output to determine decision boundaries that will be very similar to those in the initial black box model, helping to explain why predictions are made for a given set of input features.
LIME - Local
The method above gives a good explanation of the model from a top-down perspective. For a local explanation using surrogate models, you can use the LIME (Local Interpretable Model-agnostic Explanations) method proposed in a paper by Ribeiro, Singh, and Guestrin.
LIME works through the following steps:
1. Choose an input datapoint whose prediction you want to explain.
2. Generate additional input data around this datapoint using the distributions of the input feature space. The black box algorithm is then applied to this new input data and the results stored.
3. Weight the new samples according to their proximity to the instance of interest.
4. Train a weighted, interpretable model on the dataset with the variations.
5. Explain the prediction by interpreting the local model.
This means that the surrogate model LIME builds only needs to accurately reflect the black box model near the point of interest, and hence it is free of the sometimes heavy restriction of being an accurate global surrogate.
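Before reaching for a library, the steps above can be sketched by hand for the Boston regression model (a simplified illustration, assuming Gaussian perturbations and an exponential proximity kernel; real LIME implementations also discretise features and select only a handful of them):

# A simplified, hand-rolled LIME-style local surrogate for one Boston sample.
import numpy as np
from sklearn.linear_model import LinearRegression

rng = np.random.RandomState(42)
x0 = boston_X.iloc[0].values

# 1-2. perturb around the point of interest and query the black box
perturbed = x0 + rng.normal(scale=boston_X.std().values, size=(1000, len(x0)))
black_box_out = ran_for_reg.predict(perturbed)

# 3. weight samples by their proximity to x0
distances = np.linalg.norm((perturbed - x0) / boston_X.std().values, axis=1)
weights = np.exp(-(distances ** 2) / 2)

# 4. fit a weighted, interpretable (linear) surrogate locally
local_model = LinearRegression().fit(perturbed, black_box_out, sample_weight=weights)

# 5. interpret the local model's coefficients as the explanation
print(dict(zip(boston_X.columns, local_model.coef_.round(2))))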
In the following code, I use an implementation of LIME in the Skater package on the Iris dataset.
from skater.core.local_interpretation.lime.lime_tabular import LimeTabularExplainer

exp = LimeTabularExplainer(iris_X.values,
                           feature_names=list(iris_X.columns),
                           discretize_continuous=True,
                           class_names=[0, 1, 2])

doc_num = 50
print('Actual Label:', iris_y.iloc[doc_num].values)
print('Predicted Label:', xgb_clf.predict(iris_X.iloc[[doc_num]]))  # predict expects a 2D input

exp.explain_instance(iris_X.iloc[doc_num].values, xgb_clf.predict_proba).show_in_notebook()
This prints
Local interpretability
This output shows that the petal length value of 4.35 is one of the largest contributors to this datapoint being predicted as class 1. You can also see which features, and their values, suggest alternative classes.
Conclusion
In this post I have shown a few methods you can use to begin to humanise the data science process and explain simply to stakeholders why complex, hard-to-interpret models make the decisions they do. These techniques can help to bridge the gap between simple, business-focused problem solving and mathematically advanced computational techniques, and allow you to explain or validate the fundamental connections in natural phenomena when they are represented by data.