File size: 4,048 Bytes
3efedb0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
import shap
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

class ShapExplainer:
    """Wrapper around ``shap.TreeExplainer`` that produces SHAP values and explanations.

    ``model`` is either a bare tree model or an object exposing the estimator
    as ``model.model``. If the object also has a non-``None``
    ``model.preprocessor`` (anything with a ``transform`` method), every input
    is transformed with it, and densified if the result has ``toarray``,
    before it reaches SHAP.

    When the explainer returns one set of SHAP values per output, the output
    at index 1 is used everywhere in this class. It is assumed to be the
    positive class of a binary classifier.
    """

    def __init__(self, model, feature_names=None):
        """
        Args:
            model: The tree model, or a wrapper exposing ``.model`` and
                optionally ``.preprocessor`` / ``.feature_names``.
            feature_names: Names of the columns the explainer sees. If
                ``None``, ``model.feature_names`` is used when it exists.
        """
        self.model = model
        self.explainer = None  # set by fit()
        self.feature_names = feature_names

    # ------------------------------------------------------------------ helpers

    def _preprocess(self, X):
        """Apply the model's preprocessor (if any), densifying sparse output."""
        preprocessor = getattr(self.model, "preprocessor", None)
        if preprocessor is None:
            return X
        transformed = preprocessor.transform(X)
        # SHAP's tree explainer is fed dense arrays here; sparse results
        # (e.g. from one-hot encoders) are converted.
        if hasattr(transformed, "toarray"):
            transformed = transformed.toarray()
        return transformed

    def _require_fitted(self):
        """Raise a clear error when ``fit`` has not been called yet."""
        if self.explainer is None:
            raise RuntimeError(
                "ShapExplainer.fit() must be called before explaining instances"
            )

    def _resolve_feature_names(self):
        """Fill ``self.feature_names`` from ``model.feature_names`` if unset."""
        if self.feature_names is None and hasattr(self.model, "feature_names"):
            self.feature_names = self.model.feature_names
        return self.feature_names

    def _explain_processed(self, processed):
        """Return SHAP values for already-preprocessed input (output index 1 if multi-output)."""
        self._require_fitted()
        values = self.explainer.shap_values(processed)
        if isinstance(values, list):
            # Older shap: one array per output.
            values = values[1]  # assumes binary classification
        elif np.ndim(values) == 3:
            # Newer shap: a single (samples, features, outputs) array.
            values = values[..., 1]  # assumes binary classification
        return values

    def _expected_value(self):
        """Base value matching the output selected by ``_explain_processed``."""
        expected = self.explainer.expected_value
        # Multi-output explainers expose one base value per output; choose the
        # same index (1) used for the SHAP values so force plots line up.
        if np.ndim(expected) > 0 and len(expected) > 1:
            return expected[1]
        return expected

    # --------------------------------------------------------------- public API

    def fit(self, X_background):
        """Build the TreeExplainer, using ``X_background`` as background data.

        The background rows go through the same preprocessing as the
        instances that will be explained, so both live in the same feature
        space. With shap's default settings, supplying background data
        selects its interventional feature-perturbation mode. Large
        background sets can be slow; consider subsampling them first.

        Args:
            X_background: Background rows in the model's raw input format.

        Returns:
            self, so calls can be chained.
        """
        background = self._preprocess(X_background)
        tree_model = self.model.model if hasattr(self.model, "model") else self.model
        self.explainer = shap.TreeExplainer(tree_model, data=background)
        return self

    def explain_instance(self, instance):
        """Return SHAP values for ``instance`` (first axis indexes input rows).

        For multi-output explainers the values for output index 1 are
        returned. Both the list-of-arrays and the 3-D array formats are
        handled.

        Raises:
            RuntimeError: If ``fit`` has not been called.
        """
        return self._explain_processed(self._preprocess(instance))

    def generate_explanation(self, instance, original_features=None, *,
                             top_k=5, threshold=0.01):
        """Build a one-line, human-readable explanation for the first row of ``instance``.

        Features are ranked by absolute SHAP value. Among the ``top_k``
        highest-ranked, those with ``|SHAP| > threshold`` are rendered as
        ``"<name> = <value> increased|decreased the likelihood"`` and joined
        with ``" and "``.

        Args:
            instance: Input passed to ``explain_instance``.
            original_features: Table indexed by feature name
                (``original_features[name].values[0]`` is read), supplying
                the values shown in the text. Presumably a pandas DataFrame.
            top_k: Maximum number of features considered.
            threshold: Minimum absolute SHAP value for a feature to be shown.

        Returns:
            The explanation string, which may be empty. If
            ``original_features`` is ``None``, a fixed placeholder message
            is returned instead.

        Raises:
            RuntimeError: If ``fit`` has not been called.
            ValueError: If no feature names are available.
        """
        shap_values = self.explain_instance(instance)
        feature_names = self._resolve_feature_names()

        if original_features is None:
            return "Explanation requires original feature values"
        if feature_names is None:
            raise ValueError(
                "feature_names must be given (or available as model.feature_names) "
                "to build an explanation"
            )

        row = shap_values[0]
        # Keep each name paired with its own SHAP value; this avoids a
        # list.index() lookup per feature, which is O(n) and wrong for
        # duplicate names. sorted() is stable, so ties keep column order.
        ranked = sorted(
            ((name, row[idx]) for idx, name in enumerate(feature_names)),
            key=lambda item: abs(item[1]),
            reverse=True,
        )

        parts = []
        for name, contribution in ranked[:top_k]:
            if abs(contribution) > threshold:
                value = original_features[name].values[0]
                direction = "increased" if contribution > 0 else "decreased"
                parts.append(f"{name} = {value} {direction} the likelihood")
        return " and ".join(parts)

    def plot_force(self, instance, matplotlib=True):
        """Generate a SHAP force plot for ``instance``.

        The instance is preprocessed exactly as for ``explain_instance``, so
        the displayed feature values line up column-for-column with the SHAP
        values. The base value is taken from the same output as the SHAP
        values.

        Args:
            instance: Input row(s) in the model's raw input format.
            matplotlib: If True, render with matplotlib. Otherwise return
                shap's non-matplotlib (interactive) plot object.

        Returns:
            Whatever ``shap.force_plot`` returns.

        Raises:
            RuntimeError: If ``fit`` has not been called.
        """
        processed = self._preprocess(instance)
        shap_values = self._explain_processed(processed)
        return shap.force_plot(
            self._expected_value(),
            shap_values,
            processed,
            feature_names=self._resolve_feature_names(),
            matplotlib=matplotlib,
        )