Explainability API Reference

shap_analysis

SHAP analysis for model interpretability.

Functions

compute_shap_values

compute_shap_values(model, X: np.ndarray, feature_names: Optional[List[str]] = None, background_size: int = 100) -> Dict[str, Any]

Compute SHAP values for a model.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| model | | Trained model (LGBM, sklearn, or PyTorch) | required |
| X | ndarray | Input features array | required |
| feature_names | Optional[List[str]] | Optional list of feature names | None |
| background_size | int | Size of background dataset for non-tree models | 100 |

Returns:

| Type | Description |
|---|---|
| Dict[str, Any] | Dictionary with shap_values, expected_value, feature_names |
Source code in src/explainability/shap_analysis.py
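The shape of the returned dictionary can be illustrated with a small numpy sketch. The example below does not call the module; it builds an equivalent result by hand for a linear model, where the SHAP value of feature i is exactly w[i] * (x[i] - E[x[i]]) when features are independent. All names here are illustrative.

```python
import numpy as np

# Toy linear model f(x) = w @ x + b; for such a model the exact SHAP value
# of feature i for a sample x is w[i] * (x[i] - E[x[i]]).
rng = np.random.default_rng(0)
X = rng.normal(size=(200, 3))
w, b = np.array([0.5, -1.0, 2.0]), 0.1

expected_value = float(w @ X.mean(axis=0) + b)  # E[f(X)], the base value
shap_values = w * (X - X.mean(axis=0))          # shape (n_samples, n_features)

# The same three keys the Returns table above documents:
result = {
    "shap_values": shap_values,
    "expected_value": expected_value,
    "feature_names": ["f0", "f1", "f2"],
}

# Local accuracy: base value + per-feature SHAP values reconstructs each prediction.
preds = X @ w + b
recon = result["expected_value"] + result["shap_values"].sum(axis=1)
assert np.allclose(preds, recon)
```

The local-accuracy check at the end is the defining property of SHAP values and holds for any explainer, not just the linear case.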
get_feature_importance

get_feature_importance(shap_result: Dict[str, Any]) -> Dict[str, float]

Get mean absolute SHAP values as feature importance.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| shap_result | Dict[str, Any] | Result from compute_shap_values() | required |

Returns:

| Type | Description |
|---|---|
| Dict[str, float] | Dictionary mapping feature names to importance |
Source code in src/explainability/shap_analysis.py
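The computation behind this is small enough to sketch directly: mean absolute SHAP value per feature, keyed by feature name. `feature_importance_sketch` below is a hypothetical re-implementation for illustration, not the module's code.

```python
import numpy as np

def feature_importance_sketch(shap_result):
    """Mean |SHAP value| per feature, as a name -> importance dict.
    Hypothetical sketch of the documented behavior."""
    importance = np.abs(shap_result["shap_values"]).mean(axis=0)
    return dict(zip(shap_result["feature_names"], importance.tolist()))

shap_result = {
    "shap_values": np.array([[0.2, -0.5], [-0.4, 0.1], [0.0, 0.3]]),
    "expected_value": 0.0,
    "feature_names": ["age", "income"],
}
imp = feature_importance_sketch(shap_result)
# age:    mean(|0.2|, |-0.4|, |0.0|) = 0.2
# income: mean(|-0.5|, |0.1|, |0.3|) = 0.3
```

Taking the absolute value before averaging matters: positive and negative contributions would otherwise cancel and understate a feature's influence.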
plot_shap_summary

plot_shap_summary(shap_result: Dict[str, Any], X: np.ndarray, max_display: int = 10, save_path: Optional[str] = None)

Plot SHAP summary.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| shap_result | Dict[str, Any] | Result from compute_shap_values() | required |
| X | ndarray | Feature data | required |
| max_display | int | Maximum features to display | 10 |
| save_path | Optional[str] | Optional path to save figure | None |

Returns:

| Type | Description |
|---|---|
| | Matplotlib figure |
Source code in src/explainability/shap_analysis.py
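A summary plot orders features by mean absolute SHAP value and keeps at most `max_display` of them. That selection step can be sketched on its own; `top_features` is an illustrative helper, not part of the module.

```python
import numpy as np

def top_features(shap_values, feature_names, max_display=10):
    """Rank features by mean |SHAP| (descending) and keep the top
    max_display -- the ordering a summary plot displays."""
    order = np.argsort(np.abs(shap_values).mean(axis=0))[::-1]
    keep = order[:max_display]
    return [feature_names[i] for i in keep], shap_values[:, keep]

shap_values = np.array([[0.1, 0.9, -0.3],
                        [-0.1, -0.7, 0.2]])
names, cols = top_features(shap_values, ["a", "b", "c"], max_display=2)
# mean |SHAP|: a=0.1, b=0.8, c=0.25 -> keep ["b", "c"]
```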
plot_shap_waterfall

plot_shap_waterfall(shap_result: Dict[str, Any], sample_idx: int = 0, save_path: Optional[str] = None)

Plot SHAP waterfall for a single prediction.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| shap_result | Dict[str, Any] | Result from compute_shap_values() | required |
| sample_idx | int | Index of sample to explain | 0 |
| save_path | Optional[str] | Optional path to save figure | None |

Returns:

| Type | Description |
|---|---|
| | Matplotlib figure |
Source code in src/explainability/shap_analysis.py
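A waterfall plot walks from the expected value to the model's output for one sample, adding each feature's SHAP value in turn. The arithmetic behind the bars, sketched with placeholder numbers:

```python
import numpy as np

# One sample's SHAP values (e.g. shap_values[sample_idx]) plus the base value.
expected_value = 0.5
sample_shap = np.array([0.3, -0.1, 0.2])

# The waterfall's running total: starts at the base value, each feature
# shifts it up or down, and the final bar lands on the prediction.
running = expected_value + np.cumsum(sample_shap)
prediction = running[-1]   # 0.5 + 0.3 - 0.1 + 0.2 = 0.9
```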
attention_viz

Attention visualization for sequence models.

Functions

explain_model_attention

explain_model_attention(model, samples, feature_names: List[str] = None)

Get attention explanations from a model.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| model | | Model with attention (LSTMAttentionModel) | required |
| samples | | Input samples | required |
| feature_names | List[str] | Feature names | None |

Returns:

| Type | Description |
|---|---|
| | Dictionary with attention analysis |
Source code in src/explainability/attention_viz.py
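The attention analysis rests on per-position weights that form a probability distribution over the sequence. A self-contained sketch of how such weights arise (the names and the toy scores are illustrative, not the module's API):

```python
import numpy as np

def softmax(x, axis=-1):
    """Numerically stable softmax along the given axis."""
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

# Toy attention scores for a sequence of length 5: a softmax over key
# positions turns each query's scores into weights that sum to 1.
scores = np.random.default_rng(1).normal(size=(5, 5))
attention_weights = softmax(scores, axis=-1)   # shape (seq_len, seq_len)
```

Because each row sums to 1, the weights can be read directly as "how much of this query's attention goes to each timestep", which is what the plotting helpers below visualize.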
plot_attention_over_time

plot_attention_over_time(attention_weights: np.ndarray, time_values: Optional[np.ndarray] = None, focus_position: int = -1, title: str = 'Attention Focus Over Time', save_path: Optional[str] = None)

Plot how attention is distributed from a specific position.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| attention_weights | ndarray | Attention weights | required |
| time_values | Optional[ndarray] | Optional time values for x-axis | None |
| focus_position | int | Query position to visualize (-1 for last) | -1 |
| title | str | Plot title | 'Attention Focus Over Time' |
| save_path | Optional[str] | Optional path to save figure | None |

Returns:

| Type | Description |
|---|---|
| | Matplotlib figure |
Source code in src/explainability/attention_viz.py
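With the default `focus_position=-1`, the plot shows how the last query position distributes its attention over the timesteps. The selection it visualizes is just a row slice, assuming weights indexed as (query, key):

```python
import numpy as np

attention_weights = np.array([
    [1.0, 0.0, 0.0],
    [0.5, 0.5, 0.0],
    [0.2, 0.3, 0.5],
])  # causal toy example, shape (seq_len, seq_len)

focus_position = -1
row = attention_weights[focus_position]  # weights the last step assigns to each timestep
```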
plot_attention_weights

plot_attention_weights(attention_weights: np.ndarray, x_labels: Optional[List[str]] = None, y_labels: Optional[List[str]] = None, title: str = 'Attention Weights', save_path: Optional[str] = None)

Plot attention weights as a heatmap.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| attention_weights | ndarray | Attention weights array of shape (seq_len, seq_len) or (heads, seq_len, seq_len) | required |
| x_labels | Optional[List[str]] | Labels for x-axis (query positions) | None |
| y_labels | Optional[List[str]] | Labels for y-axis (key positions) | None |
| title | str | Plot title | 'Attention Weights' |
| save_path | Optional[str] | Optional path to save figure | None |

Returns:

| Type | Description |
|---|---|
| | Matplotlib figure |
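A minimal heatmap along the lines the function describes, sketched with plain matplotlib rather than the module itself (Agg backend so it runs headless; a (heads, seq_len, seq_len) input would be averaged or plotted per head first). Axis labels follow the documented convention above.

```python
import numpy as np
import matplotlib
matplotlib.use("Agg")            # headless backend for scripts/CI
import matplotlib.pyplot as plt

attention_weights = np.array([[0.7, 0.3],
                              [0.4, 0.6]])  # (seq_len, seq_len)

fig, ax = plt.subplots()
im = ax.imshow(attention_weights, cmap="viridis", aspect="auto")
ax.set_title("Attention Weights")
ax.set_xlabel("Query position")  # per the x_labels convention documented above
ax.set_ylabel("Key position")
fig.colorbar(im, ax=ax)
# fig.savefig("attention.png")   # the save_path behavior, roughly
```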