import shap
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt


class ShapExplainer:
    """Wraps shap.TreeExplainer around a tree model, handling an optional
    preprocessor attached to the model."""

    def __init__(self, model, feature_names=None):
        self.model = model
        self.explainer = None
        self.feature_names = feature_names

    def fit(self, X_background):
        """Fit the explainer with background data."""
        # Transform the background data if the model carries a preprocessor
        if hasattr(self.model, 'preprocessor') and self.model.preprocessor is not None:
            X_processed = self.model.preprocessor.transform(X_background)
            # Convert to a dense array if the transform returned a sparse matrix
            if hasattr(X_processed, "toarray"):
                X_processed = X_processed.toarray()
        else:
            X_processed = X_background

        # Create the explainer, unwrapping self.model.model when self.model is a
        # wrapper object. Passing the processed background data makes TreeExplainer
        # use it for (interventional) feature perturbation.
        tree_model = self.model.model if hasattr(self.model, 'model') else self.model
        self.explainer = shap.TreeExplainer(tree_model, data=X_processed)

    def explain_instance(self, instance):
        """Generate SHAP values for a single instance."""
        # Apply the same preprocessing the model saw at training time
        if hasattr(self.model, 'preprocessor') and self.model.preprocessor is not None:
            instance_processed = self.model.preprocessor.transform(instance)
            # Convert to a dense array if the transform returned a sparse matrix
            if hasattr(instance_processed, "toarray"):
                instance_processed = instance_processed.toarray()
        else:
            instance_processed = instance

        # Calculate SHAP values
        shap_values = self.explainer.shap_values(instance_processed)

        # For classification models, SHAP returns one set of values per class
        # (older versions a list of arrays, newer versions a 3-D array); keep
        # the positive class, assuming binary classification
        if isinstance(shap_values, list):
            shap_values = shap_values[1]
        elif getattr(shap_values, "ndim", 0) == 3:
            shap_values = shap_values[:, :, 1]
        return shap_values

    def generate_explanation(self, instance, original_features=None):
        """Generate a human-readable explanation for one prediction."""
        shap_values = self.explain_instance(instance)

        # Fall back to feature names stored on the model, if any
        if self.feature_names is None and hasattr(self.model, 'feature_names'):
            self.feature_names = self.model.feature_names

        # We need both the feature names and the original (untransformed) values
        if original_features is not None and self.feature_names is not None:
            # Sort features by absolute SHAP value, keeping the signed value
            feature_importance = [(name, shap_values[0][i])
                                  for i, name in enumerate(self.feature_names)]
            feature_importance.sort(key=lambda x: abs(x[1]), reverse=True)

            # Build the explanation text from the top 5 features
            explanation = []
            for feature, shap_value in feature_importance[:5]:
                if abs(shap_value) > 0.01:  # Only include significant features
                    value = original_features[feature].values[0]
                    direction = "increased" if shap_value > 0 else "decreased"
                    explanation.append(f"{feature} = {value} {direction} the likelihood")
            return " and ".join(explanation)
        return "Explanation requires feature names and original feature values"

    def plot_force(self, instance, matplotlib=True):
        """Generate a force plot for an instance."""
        shap_values = self.explain_instance(instance)

        # For multi-output models expected_value is a list/array; take the
        # positive class so it matches the SHAP values returned above
        expected_value = self.explainer.expected_value
        if isinstance(expected_value, (list, np.ndarray)) and np.size(expected_value) > 1:
            expected_value = expected_value[1]

        # Note: `instance` is plotted as-is, so this assumes the preprocessor
        # (if any) does not change the number or order of features
        if matplotlib:
            shap.force_plot(expected_value,
                            shap_values,
                            instance,
                            feature_names=self.feature_names,
                            matplotlib=True)
        else:
            return shap.force_plot(expected_value,
                                   shap_values,
                                   instance,
                                   feature_names=self.feature_names)
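

# A minimal usage sketch, not part of the class above. Assumptions: a plain
# scikit-learn tree model with no `preprocessor`/`model` wrapper attributes,
# and the toy breast-cancer dataset standing in for real data.
if __name__ == "__main__":
    from sklearn.datasets import load_breast_cancer
    from sklearn.ensemble import RandomForestClassifier

    # Train a small tree model on a binary-classification dataset
    data = load_breast_cancer(as_frame=True)
    X, y = data.data, data.target
    clf = RandomForestClassifier(n_estimators=50, random_state=0).fit(X, y)

    # Fit the explainer on a background sample, then explain one prediction
    explainer = ShapExplainer(clf, feature_names=list(X.columns))
    explainer.fit(X.sample(100, random_state=0))
    row = X.iloc[[0]]
    print(explainer.generate_explanation(row, original_features=row))
    # explainer.plot_force(row)  # draws a matplotlib force plot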