Skip to content

Explainability API Reference

shap_analysis

SHAP analysis for model interpretability.

Functions

compute_shap_values

compute_shap_values(model, X: np.ndarray, feature_names: Optional[List[str]] = None, background_size: int = 100) -> Dict[str, Any]

Compute SHAP values for a model.

Parameters:

Name Type Description Default
model

Trained model (LGBM, sklearn, or PyTorch)

required
X ndarray

Input features array

required
feature_names Optional[List[str]]

Optional list of feature names

None
background_size int

Size of background dataset for non-tree models

100

Returns:

Type Description
Dict[str, Any]

Dictionary with shap_values, expected_value, feature_names

Source code in src/explainability/shap_analysis.py
def compute_shap_values(model, X: np.ndarray, 
                        feature_names: Optional[List[str]] = None,
                        background_size: int = 100) -> Dict[str, Any]:
    """Compute SHAP values for a model.

    Picks an explainer based on the model: TreeExplainer for models that
    look like tree/LGBM wrappers, otherwise KernelExplainer seeded with a
    background sample taken from the head of ``X``.

    Args:
        model: Trained model (LGBM, sklearn, or PyTorch)
        X: Input features array
        feature_names: Optional list of feature names
        background_size: Size of background dataset for non-tree models

    Returns:
        Dictionary with shap_values, expected_value, feature_names.
        On failure — or when SHAP is not installed — a dict containing
        only an 'error' key is returned instead.
    """
    if not HAS_SHAP:
        return {'error': 'SHAP not installed'}

    try:
        # Try TreeExplainer first (for LGBM, XGBoost, etc.)
        # NOTE(review): any object with a `.model` attribute is routed to
        # TreeExplainer — confirm this matches the project's model wrappers;
        # a non-tree wrapper with `.model` would land here and likely raise.
        if hasattr(model, 'booster_') or hasattr(model, 'model'):
            # LGBM wrapper: unwrap `.model` if present, else use the object itself.
            lgbm_model = getattr(model, 'model', model)
            explainer = shap.TreeExplainer(lgbm_model)
            shap_values = explainer.shap_values(X)
            expected_value = explainer.expected_value
        else:
            # Use KernelExplainer for other models.
            # The background sample anchors the expected value; capped at
            # `background_size` rows (or all of X if smaller).
            background = X[:min(background_size, len(X))]

            # Create prediction function
            if hasattr(model, 'predict'):
                predict_fn = model.predict
            else:
                # Assume it's a callable
                predict_fn = lambda x: model(x).reshape(-1, 1)

            explainer = shap.KernelExplainer(predict_fn, background)
            # KernelExplainer is slow, so only the first 100 rows are
            # explained here (hard-coded cap, independent of background_size).
            shap_values = explainer.shap_values(X[:min(100, len(X))])
            expected_value = explainer.expected_value

        # Ensure feature_names matches the number of features
        if feature_names is None:
            feature_names = [f'f{i}' for i in range(X.shape[1])]
        elif len(feature_names) != X.shape[1]:
            logger.warning(f"Feature names count ({len(feature_names)}) doesn't match feature dimension ({X.shape[1]}). Using generic names.")
            feature_names = [f'f{i}' for i in range(X.shape[1])]

        # Handle multi-output shap_values (list of arrays)
        if isinstance(shap_values, list):
            # For multi-output, use first output or average
            # NOTE(review): only the first output is kept; remaining outputs
            # are discarded — confirm this is intended for multi-class models.
            shap_values = shap_values[0] if len(shap_values) > 0 else shap_values

        return {
            'shap_values': shap_values,
            'expected_value': expected_value,
            'feature_names': feature_names,
        }

    except Exception as e:
        # Best-effort API: log and surface the failure as an error dict
        # rather than raising, so callers can degrade gracefully.
        logger.error(f"Failed to compute SHAP values: {e}")
        return {'error': str(e)}

get_feature_importance

get_feature_importance(shap_result: Dict[str, Any]) -> Dict[str, float]

Get mean absolute SHAP values as feature importance.

Parameters:

Name Type Description Default
shap_result Dict[str, Any]

Result from compute_shap_values()

required

Returns:

Type Description
Dict[str, float]

Dictionary mapping feature names to importance

Source code in src/explainability/shap_analysis.py
def get_feature_importance(shap_result: Dict[str, Any]) -> Dict[str, float]:
    """Get mean absolute SHAP values as feature importance.

    Args:
        shap_result: Result from compute_shap_values()

    Returns:
        Dictionary mapping feature names to importance; empty dict when
        shap_result carries an 'error' entry.
    """
    if 'error' in shap_result:
        return {}

    values = np.asarray(shap_result['shap_values'])
    names = shap_result['feature_names']

    # Importance of a feature = average magnitude of its SHAP contribution
    # across all explained samples.
    per_feature = np.abs(values).mean(axis=0)

    return {name: float(score) for name, score in zip(names, per_feature)}

plot_shap_summary

plot_shap_summary(shap_result: Dict[str, Any], X: np.ndarray, max_display: int = 10, save_path: Optional[str] = None)

Plot SHAP summary.

Parameters:

Name Type Description Default
shap_result Dict[str, Any]

Result from compute_shap_values()

required
X ndarray

Feature data

required
max_display int

Maximum features to display

10
save_path Optional[str]

Optional path to save figure

None

Returns:

Type Description

Matplotlib figure, or None if SHAP is not installed or shap_result contains an error

Source code in src/explainability/shap_analysis.py
def plot_shap_summary(shap_result: Dict[str, Any], 
                       X: np.ndarray,
                       max_display: int = 10,
                       save_path: Optional[str] = None):
    """Plot SHAP summary.

    Args:
        shap_result: Result from compute_shap_values()
        X: Feature data
        max_display: Maximum features to display
        save_path: Optional path to save figure

    Returns:
        Matplotlib figure, or None when SHAP is missing or shap_result
        carries an error.
    """
    if not HAS_SHAP:
        logger.warning("SHAP not installed")
        return None

    if 'error' in shap_result:
        logger.warning(f"SHAP error: {shap_result['error']}")
        return None

    import matplotlib.pyplot as plt

    fig, ax = plt.subplots(figsize=(10, 6))

    values = np.array(shap_result['shap_values'])
    names = shap_result['feature_names']

    # The summary plot needs 2D values (n_samples, n_features); collapse a
    # trailing singleton axis left over from single-output models.
    if values.ndim == 3 and values.shape[-1] == 1:
        values = values.squeeze(axis=-1)

    shap.summary_plot(
        values, X,
        feature_names=names,
        max_display=max_display,
        show=False
    )

    if save_path:
        plt.savefig(save_path, dpi=150, bbox_inches='tight')
        logger.info(f"Saved SHAP summary to {save_path}")

    return fig

plot_shap_waterfall

plot_shap_waterfall(shap_result: Dict[str, Any], sample_idx: int = 0, save_path: Optional[str] = None)

Plot SHAP waterfall for a single prediction.

Parameters:

Name Type Description Default
shap_result Dict[str, Any]

Result from compute_shap_values()

required
sample_idx int

Index of sample to explain

0
save_path Optional[str]

Optional path to save figure

None

Returns:

Type Description

Matplotlib figure, or None if SHAP is not installed or shap_result contains an error

Source code in src/explainability/shap_analysis.py
def plot_shap_waterfall(shap_result: Dict[str, Any],
                         sample_idx: int = 0,
                         save_path: Optional[str] = None):
    """Plot SHAP waterfall for a single prediction.

    Args:
        shap_result: Result from compute_shap_values()
        sample_idx: Index of sample to explain
        save_path: Optional path to save figure

    Returns:
        Matplotlib figure, or None when SHAP is missing or shap_result
        carries an error.
    """
    if not HAS_SHAP:
        return None
    if 'error' in shap_result:
        return None

    import matplotlib.pyplot as plt

    # Wrap the chosen sample's contributions in a shap.Explanation, which
    # is the input type the waterfall plot API consumes.
    explanation = shap.Explanation(
        values=shap_result['shap_values'][sample_idx],
        base_values=shap_result['expected_value'],
        feature_names=shap_result['feature_names']
    )

    fig, ax = plt.subplots(figsize=(10, 6))
    shap.waterfall_plot(explanation, show=False)

    if save_path:
        plt.savefig(save_path, dpi=150, bbox_inches='tight')

    return fig

attention_viz

Attention visualization for sequence models.

Functions

explain_model_attention

explain_model_attention(model, samples, feature_names: List[str] = None)

Get attention explanations from a model.

Parameters:

Name Type Description Default
model

Model with attention (LSTMAttentionModel)

required
samples

Input samples

required
feature_names List[str]

Feature names

None

Returns:

Type Description

Dictionary with attention analysis, or a dictionary containing only an 'error' key when the model has no explain method or exposes no attention weights

Source code in src/explainability/attention_viz.py
def explain_model_attention(model, samples, feature_names: List[str] = None):
    """Get attention explanations from a model.

    Args:
        model: Model with attention (LSTMAttentionModel)
        samples: Input samples — either objects whose ``to_tensor().x``
            yields a tensor, or an array/tensor passed through unchanged
        feature_names: Feature names (not used by the analysis itself)

    Returns:
        Dictionary with attention analysis, or a dict with an 'error' key
        when the model cannot be explained.
    """
    if not hasattr(model, 'explain'):
        return {'error': 'Model does not have explain method'}

    # Guard: an empty batch would make `samples[0]` below raise IndexError.
    if len(samples) == 0:
        return {'error': 'No samples provided'}

    # Build the model input. torch is imported only in the branch that
    # needs it, so array/tensor inputs work without torch installed.
    if hasattr(samples[0], 'x'):
        import torch
        x = torch.stack([s.to_tensor().x for s in samples])
    else:
        x = samples

    # Get attention weights
    explanation = model.explain(x)

    if 'attention_weights' not in explanation:
        return {'error': 'No attention weights available'}

    attn = explanation['attention_weights']

    # Summarize the attention pattern:
    #  - mean over the first two axes gives average attention per position
    #    (assumes attn is (batch, heads, seq, seq) — TODO confirm)
    #  - entropy over the last axis measures how diffuse attention is;
    #    the 1e-10 offset avoids log(0)
    analysis = {
        'attention_weights': attn,
        'mean_attention_per_position': attn.mean(axis=(0, 1)),
        'attention_entropy': -np.sum(attn * np.log(attn + 1e-10), axis=-1).mean(),
    }

    return analysis

plot_attention_over_time

plot_attention_over_time(attention_weights: np.ndarray, time_values: Optional[np.ndarray] = None, focus_position: int = -1, title: str = 'Attention Focus Over Time', save_path: Optional[str] = None)

Plot how attention is distributed from a specific position.

Parameters:

Name Type Description Default
attention_weights ndarray

Attention weights

required
time_values Optional[ndarray]

Optional time values for x-axis

None
focus_position int

Query position to visualize (-1 for last)

-1
title str

Plot title

'Attention Focus Over Time'
save_path Optional[str]

Optional path to save figure

None

Returns:

Type Description

Matplotlib figure

Source code in src/explainability/attention_viz.py
def plot_attention_over_time(attention_weights: np.ndarray,
                              time_values: Optional[np.ndarray] = None,
                              focus_position: int = -1,
                              title: str = "Attention Focus Over Time",
                              save_path: Optional[str] = None):
    """Plot how attention is distributed from a specific position.

    Args:
        attention_weights: Attention weights
        time_values: Optional time values for x-axis
        focus_position: Query position to visualize (-1 for last)
        title: Plot title
        save_path: Optional path to save figure

    Returns:
        Matplotlib figure
    """
    import matplotlib.pyplot as plt

    # Multi-head input (heads, seq, seq): average the heads away first.
    if attention_weights.ndim == 3:
        attention_weights = attention_weights.mean(axis=0)

    # One row of the attention matrix = where this query position looks.
    row = attention_weights[focus_position]

    xs = time_values if time_values is not None else np.arange(len(row))

    fig, ax = plt.subplots(figsize=(12, 4))
    ax.bar(xs, row, alpha=0.7, color='steelblue')
    ax.set_xlabel('Time')
    ax.set_ylabel('Attention Weight')
    ax.set_title(title)

    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=150, bbox_inches='tight')

    return fig

plot_attention_weights

plot_attention_weights(attention_weights: np.ndarray, x_labels: Optional[List[str]] = None, y_labels: Optional[List[str]] = None, title: str = 'Attention Weights', save_path: Optional[str] = None)

Plot attention weights as heatmap.

Parameters:

Name Type Description Default
attention_weights ndarray

Attention weights array of shape (seq_len, seq_len) or (heads, seq_len, seq_len)

required
x_labels Optional[List[str]]

Labels for x-axis (query positions)

None
y_labels Optional[List[str]]

Labels for y-axis (key positions)

None
title str

Plot title

'Attention Weights'
save_path Optional[str]

Optional path to save figure

None

Returns:

Type Description

Matplotlib figure

Source code in src/explainability/attention_viz.py
def plot_attention_weights(attention_weights: np.ndarray,
                            x_labels: Optional[List[str]] = None,
                            y_labels: Optional[List[str]] = None,
                            title: str = "Attention Weights",
                            save_path: Optional[str] = None):
    """Plot attention weights as heatmap.

    Args:
        attention_weights: Attention weights array of shape (seq_len, seq_len) or (heads, seq_len, seq_len)
        x_labels: Labels for x-axis (query positions)
        y_labels: Labels for y-axis (key positions)
        title: Plot title
        save_path: Optional path to save figure

    Returns:
        Matplotlib figure
    """
    import matplotlib.pyplot as plt

    # Reduce per-head weights (heads, seq, seq) to one averaged matrix.
    if attention_weights.ndim == 3:
        attention_weights = attention_weights.mean(axis=0)

    n = attention_weights.shape[0]

    fig, ax = plt.subplots(figsize=(10, 8))
    heatmap = ax.imshow(attention_weights, cmap='viridis', aspect='auto')

    colorbar = plt.colorbar(heatmap, ax=ax)
    colorbar.set_label('Attention Weight')

    # Default to generic time-step labels; mirror x labels onto y if absent.
    if x_labels is None:
        x_labels = [f't{i}' for i in range(n)]
    if y_labels is None:
        y_labels = x_labels

    ax.set_xticks(range(len(x_labels)))
    ax.set_yticks(range(len(y_labels)))
    ax.set_xticklabels(x_labels, rotation=45, ha='right')
    ax.set_yticklabels(y_labels)

    ax.set_xlabel('Key Position')
    ax.set_ylabel('Query Position')
    ax.set_title(title)

    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=150, bbox_inches='tight')
        logger.info(f"Saved attention plot to {save_path}")

    return fig