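"""SHAP explainability helper: wraps shap.TreeExplainer around a tree model
that may optionally carry an attached preprocessing step."""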
import shap
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt


class ShapExplainer:
    """SHAP-based explainer for tree models, with optional preprocessing."""

    def __init__(self, model, feature_names=None):
        # `model` may be a plain estimator or a wrapper exposing
        # `.model` (the estimator) and `.preprocessor` (a transformer)
        self.model = model
        self.explainer = None
        self.feature_names = feature_names

    def _preprocess(self, X):
        """Apply the model's preprocessor, if any, and densify sparse output"""
        if hasattr(self.model, 'preprocessor') and self.model.preprocessor is not None:
            X = self.model.preprocessor.transform(X)
            # Convert to dense array if sparse
            if hasattr(X, "toarray"):
                X = X.toarray()
        return X

    def fit(self, X_background):
        """Fit the explainer with background data"""
        X_processed = self._preprocess(X_background)
        # Unwrap the underlying estimator if the model is a wrapper
        estimator = self.model.model if hasattr(self.model, 'model') else self.model
        # Pass the processed background data so it serves as the SHAP baseline
        self.explainer = shap.TreeExplainer(estimator, data=X_processed)

    def explain_instance(self, instance):
        """Generate SHAP values for a single instance"""
        instance_processed = self._preprocess(instance)
        # Calculate SHAP values
        shap_values = self.explainer.shap_values(instance_processed)
        # For classification models, older SHAP versions return a list with
        # one array per class; keep the positive class (assumes binary
        # classification). Newer versions return an (n, features, classes)
        # array instead, so slice out the same class there too.
        if isinstance(shap_values, list):
            shap_values = shap_values[1]
        elif getattr(shap_values, "ndim", 0) == 3:
            shap_values = shap_values[..., 1]
        return shap_values

    def generate_explanation(self, instance, original_features=None):
        """Generate human-readable explanation"""
        shap_values = self.explain_instance(instance)
        # Fall back to feature names stored on the model, if available
        if self.feature_names is None and hasattr(self.model, 'feature_names'):
            self.feature_names = self.model.feature_names
        # If we have the original feature names and values
        if original_features is not None and self.feature_names is not None:
            # Sort features by absolute SHAP value, keeping the column index
            # so the signed value can be looked up again below
            feature_importance = [(name, i, abs(shap_values[0][i]))
                                  for i, name in enumerate(self.feature_names)]
            feature_importance.sort(key=lambda x: x[2], reverse=True)
            # Generate explanation text from the top 5 features
            explanation = []
            for feature, idx, importance in feature_importance[:5]:
                if importance > 0.01:  # Only include significant features
                    value = original_features[feature].values[0]
                    direction = "increased" if shap_values[0][idx] > 0 else "decreased"
                    explanation.append(f"{feature} = {value} {direction} the likelihood")
            return " and ".join(explanation)
        return "Explanation requires original feature values and feature names"

    def plot_force(self, instance, matplotlib=True):
        """Generate force plot for an instance"""
        shap_values = self.explain_instance(instance)
        # Use the processed instance so its shape matches the SHAP values
        instance_processed = self._preprocess(instance)
        expected_value = self.explainer.expected_value
        # For classifiers the expected value is per-class; keep the positive
        # class to match the SHAP values selected in explain_instance
        if hasattr(expected_value, '__len__') and len(expected_value) > 1:
            expected_value = expected_value[1]
        if matplotlib:
            shap.force_plot(expected_value,
                            shap_values,
                            instance_processed,
                            feature_names=self.feature_names,
                            matplotlib=True)
        else:
            return shap.force_plot(expected_value,
                                   shap_values,
                                   instance_processed,
                                   feature_names=self.feature_names)
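

if __name__ == "__main__":
    # Minimal usage sketch. Assumes scikit-learn is installed; the toy
    # dataset, column names, and model below are illustrative only and
    # are not part of this module.
    import pandas as pd
    from sklearn.ensemble import RandomForestClassifier

    X = pd.DataFrame({"age": [25, 40, 60, 35],
                      "income": [30000, 60000, 90000, 45000]})
    y = [0, 0, 1, 1]
    clf = RandomForestClassifier(n_estimators=10, random_state=0).fit(X, y)

    explainer = ShapExplainer(clf, feature_names=list(X.columns))
    explainer.fit(X)  # background data used as the SHAP baseline

    row = X.iloc[[0]]  # keep the 2-D shape expected for a single instance
    print(explainer.explain_instance(row))
    print(explainer.generate_explanation(row, original_features=row))
    # explainer.plot_force(row)  # renders a matplotlib force plot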