Skip to content

Training API Reference

The training module provides the core training infrastructure, including the Trainer class (handling mixed precision and early stopping), a registry for loss functions (Huber, MSE, Physics-Informed), and evaluation metrics tailored for SOH prediction.

trainer

Training loop with AMP, gradient clipping, early stopping.

Classes

Trainer

Trainer(model: nn.Module, config: Dict[str, Any], tracker: Optional[BaseTracker] = None, device: str = 'auto', loss_config: Optional[Dict[str, Any]] = None, verbose: bool = False)

Training loop with AMP, gradient clipping, early stopping.

Features: - Automatic mixed precision (AMP) for faster training on GPU - Gradient clipping for stability - CosineAnnealing learning rate schedule - Early stopping with patience - Automatic Sample → DataLoader conversion

Example usage

trainer = Trainer(model, config, tracker)
history = trainer.fit(train_samples, val_samples)

Initialize the trainer.

Parameters:

Name Type Description Default
model Module

PyTorch model

required
config Dict[str, Any]

Training configuration

required
tracker Optional[BaseTracker]

Experiment tracker

None
device str

Device to use ('auto', 'cuda', 'cpu')

'auto'
loss_config Optional[Dict[str, Any]]

Loss function configuration dict with 'name' key and loss-specific parameters. If None, defaults to MSE.

None
verbose bool

If True, print training progress every epoch

False
Source code in src/training/trainer.py
def __init__(self,
             model: nn.Module,
             config: Dict[str, Any],
             tracker: Optional[BaseTracker] = None,
             device: str = 'auto',
             loss_config: Optional[Dict[str, Any]] = None,
             verbose: bool = False):
    """Set up model, criterion, optimizer, scheduler and AMP state.

    Args:
        model: PyTorch model
        config: Training configuration
        tracker: Experiment tracker
        device: Device to use ('auto', 'cuda', 'cpu')
        loss_config: Loss function configuration dict with 'name' key
                     and loss-specific parameters. If None, defaults to MSE.
        verbose: If True, print training progress every epoch
    """
    # Resolve 'auto' to a concrete device before anything depends on it.
    if device == 'auto':
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
    self.device = device

    self.model = model.to(device)
    self.config = config
    self.tracker = tracker
    self.verbose = verbose

    # Criterion comes from loss_config; plain MSE when none is given.
    if loss_config:
        loss_name = loss_config.get('name', 'mse')
        loss_params = {k: v for k, v in loss_config.items() if k != 'name'}
        self.criterion = LossRegistry.get(loss_name, **loss_params)
        logger.info(f"Using loss: {loss_name} with params: {loss_params}")
    else:
        self.criterion = LossRegistry.get('mse')

    # AdamW with config-driven LR / weight decay.
    self.optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=config.get('learning_rate', 1e-3),
        weight_decay=config.get('weight_decay', 0.01),
    )

    # Warm-restart cosine schedule; restart period doubles each cycle.
    self.scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
        self.optimizer,
        T_0=config.get('scheduler_T0', 50),
        T_mult=2,
    )

    # Mixed precision is only meaningful on CUDA with AMP available.
    self.use_amp = (
        config.get('use_amp', True) and 
        device == 'cuda' and 
        HAS_AMP
    )
    self.scaler = None
    if self.use_amp:
        # Newer torch wants the device string; the legacy API takes none.
        self.scaler = GradScaler('cuda') if AMP_NEW_API else GradScaler()

    self.grad_clip = config.get('gradient_clip', 1.0)

    # Best-model bookkeeping used by early stopping in fit().
    self.best_val_loss = float('inf')
    self.best_state = None
    self.patience_counter = 0

    logger.info(f"Trainer initialized: device={device}, AMP={self.use_amp}")
Functions
fit
fit(train_samples: List[Sample], val_samples: List[Sample], epochs: Optional[int] = None, patience: Optional[int] = None) -> Dict[str, List[float]]

Train the model.

Parameters:

Name Type Description Default
train_samples List[Sample]

Training samples

required
val_samples List[Sample]

Validation samples

required
epochs Optional[int]

Number of epochs (default from config)

None
patience Optional[int]

Early stopping patience (default from config)

None

Returns:

Type Description
Dict[str, List[float]]

Training history with train_loss and val_loss

Source code in src/training/trainer.py
def fit(self,
        train_samples: List[Sample],
        val_samples: List[Sample],
        epochs: Optional[int] = None,
        patience: Optional[int] = None) -> Dict[str, List[float]]:
    """Train the model with early stopping and best-weights restore.

    Args:
        train_samples: Training samples
        val_samples: Validation samples
        epochs: Number of epochs (default from config)
        patience: Early stopping patience (default from config)

    Returns:
        Training history with train_loss, val_loss and lr per epoch
    """
    epochs = epochs or self.config.get('epochs', 100)
    patience = patience or self.config.get('early_stopping_patience', 20)
    batch_size = self.config.get('batch_size', 32)

    train_loader = self._samples_to_loader(train_samples, batch_size, shuffle=True)
    val_loader = self._samples_to_loader(val_samples, batch_size, shuffle=False)

    history = {'train_loss': [], 'val_loss': [], 'lr': []}

    logger.info(f"Starting training: {epochs} epochs, batch_size={batch_size}")

    for epoch in range(epochs):
        # Train / validate
        train_loss = self._train_epoch(train_loader)
        val_loss = self._validate(val_loader)

        current_lr = self.optimizer.param_groups[0]['lr']

        history['train_loss'].append(train_loss)
        history['val_loss'].append(val_loss)
        history['lr'].append(current_lr)

        # Log to experiment tracker if one is attached
        if self.tracker:
            self.tracker.log_metrics({
                'train_loss': train_loss,
                'val_loss': val_loss,
                'lr': current_lr,
            }, step=epoch)

        # Scheduler step (per-epoch cosine warm restarts)
        self.scheduler.step()

        # Early stopping. BUG FIX: remember whether this epoch improved
        # BEFORE overwriting best_val_loss — the verbose status below
        # previously re-tested `val_loss < self.best_val_loss` after the
        # update, so "✓ (best)" could never be shown.
        improved = val_loss < self.best_val_loss
        if improved:
            self.best_val_loss = val_loss
            # Snapshot weights on CPU so restoring works regardless of device
            self.best_state = {k: v.cpu().clone() for k, v in self.model.state_dict().items()}
            self.patience_counter = 0
        else:
            self.patience_counter += 1

        # Print progress
        if self.verbose:
            # Print every epoch if verbose, updating the same line
            status = f"  Epoch {epoch+1:3d}/{epochs} | Train Loss: {train_loss:.5f} | Val Loss: {val_loss:.5f} | LR: {current_lr:.2e}"
            if improved:
                status += " ✓ (best)"
            else:
                status += f" (patience: {self.patience_counter}/{patience})"
            print(f"\r{status}", end='', flush=True)
        elif (epoch + 1) % 10 == 0:
            logger.info(f"Epoch {epoch+1}/{epochs} - Train: {train_loss:.5f}, Val: {val_loss:.5f}, LR: {current_lr:.2e}")

        if self.patience_counter >= patience:
            logger.info(f"Early stopping at epoch {epoch + 1}")
            break

    # Terminate the in-place verbose status line
    if self.verbose:
        print()

    # Restore best weights found during training
    if self.best_state:
        self.model.load_state_dict(self.best_state)
        logger.info(f"Restored best model with val_loss={self.best_val_loss:.5f}")

    return history
load
load(path: Path) -> None

Load model and optimizer state.

Parameters:

Name Type Description Default
path Path

Path to checkpoint

required
Source code in src/training/trainer.py
def load(self, path: Path) -> None:
    """Restore model/optimizer state (and scheduler state, when saved).

    Args:
        path: Path to checkpoint
    """
    checkpoint = torch.load(path, map_location=self.device)

    self.model.load_state_dict(checkpoint['model_state'])
    self.optimizer.load_state_dict(checkpoint['optimizer_state'])

    # Older checkpoints were written before scheduler state was saved.
    if 'scheduler_state' in checkpoint:
        self.scheduler.load_state_dict(checkpoint['scheduler_state'])

    self.best_val_loss = checkpoint.get('best_val_loss', float('inf'))
    logger.info(f"Loaded checkpoint from {path}")
predict
predict(samples: List[Sample]) -> np.ndarray

Get predictions for samples.

Parameters:

Name Type Description Default
samples List[Sample]

List of samples

required

Returns:

Type Description
ndarray

Numpy array of predictions

Source code in src/training/trainer.py
@torch.no_grad()
def predict(self, samples: List[Sample]) -> np.ndarray:
    """Run inference over samples and return the stacked predictions.

    Args:
        samples: List of samples

    Returns:
        Numpy array of predictions
    """
    self.model.eval()

    loader = self._samples_to_loader(
        samples, self.config.get('batch_size', 32), shuffle=False)

    outputs = []
    for batch in loader:
        tensors = [b.to(self.device) for b in batch]
        # Batches are (X, y) or (X, y, t); targets are unused here.
        if len(tensors) == 3:
            X, _, t = tensors
        else:
            X, _ = tensors
            t = None

        outputs.append(self.model(X, t=t).cpu().numpy())

    return np.vstack(outputs)
save
save(path: Path) -> None

Save model and optimizer state.

Parameters:

Name Type Description Default
path Path

Path to save checkpoint

required
Source code in src/training/trainer.py
def save(self, path: Path) -> None:
    """Write model/optimizer/scheduler state plus config to a checkpoint.

    Args:
        path: Path to save checkpoint
    """
    checkpoint = {
        'model_state': self.model.state_dict(),
        'optimizer_state': self.optimizer.state_dict(),
        'scheduler_state': self.scheduler.state_dict(),
        'best_val_loss': self.best_val_loss,
        'config': self.config,
    }
    torch.save(checkpoint, path)
    logger.info(f"Saved checkpoint to {path}")

metrics

Evaluation metrics for battery degradation prediction.

Classes

Functions

compute_metrics

compute_metrics(y_true: np.ndarray, y_pred: np.ndarray) -> Dict[str, float]

Compute regression metrics.

Parameters:

Name Type Description Default
y_true ndarray

Ground truth values

required
y_pred ndarray

Predicted values

required

Returns:

Type Description
Dict[str, float]

Dictionary with RMSE, MAE, MAPE, R² metrics

Source code in src/training/metrics.py
def compute_metrics(y_true: np.ndarray, y_pred: np.ndarray) -> Dict[str, float]:
    """Compute regression metrics for SOH prediction.

    Args:
        y_true: Ground truth values
        y_pred: Predicted values

    Returns:
        Dictionary with 'rmse', 'mae', 'mape', 'r2', 'max_ae' keys

    Raises:
        ValueError: If either input is empty or lengths do not match
    """
    y_true = np.asarray(y_true).flatten()
    y_pred = np.asarray(y_pred).flatten()

    # Fail fast with a clear message rather than an opaque numpy
    # reduction error (empty input) or silent broadcasting of a
    # length-1 array against a longer one (length mismatch).
    if y_true.size == 0 or y_pred.size == 0:
        raise ValueError("y_true and y_pred must be non-empty")
    if y_true.shape != y_pred.shape:
        raise ValueError(
            f"Length mismatch: y_true has {y_true.size} values, "
            f"y_pred has {y_pred.size}")

    errors = y_true - y_pred

    # RMSE
    rmse = np.sqrt(np.mean(errors ** 2))

    # MAE
    mae = np.mean(np.abs(errors))

    # MAPE, computed only over nonzero targets to avoid division by zero
    mask = y_true != 0
    if mask.sum() > 0:
        mape = np.mean(np.abs((y_true[mask] - y_pred[mask]) / y_true[mask])) * 100
    else:
        mape = 0.0

    # R² (coefficient of determination); 0.0 for a constant target
    ss_res = np.sum(errors ** 2)
    ss_tot = np.sum((y_true - np.mean(y_true)) ** 2)
    r2 = 1 - (ss_res / ss_tot) if ss_tot > 0 else 0

    # Worst-case absolute error
    max_ae = np.max(np.abs(errors))

    return {
        'rmse': float(rmse),
        'mae': float(mae),
        'mape': float(mape),
        'r2': float(r2),
        'max_ae': float(max_ae),
    }

evaluate_by_group

evaluate_by_group(samples: List[Sample], predictions: np.ndarray, group_key: str = 'temperature_C') -> Dict[str, Dict[str, float]]

Compute metrics per group (e.g., per temperature).

Parameters:

Name Type Description Default
samples List[Sample]

List of Sample objects

required
predictions ndarray

Predictions array

required
group_key str

Meta key to group by

'temperature_C'

Returns:

Type Description
Dict[str, Dict[str, float]]

Dictionary mapping group values to metrics

Source code in src/training/metrics.py
def evaluate_by_group(samples: List[Sample], predictions: np.ndarray,
                      group_key: str = 'temperature_C') -> Dict[str, Dict[str, float]]:
    """Compute metrics separately for each group of samples.

    Args:
        samples: List of Sample objects
        predictions: Predictions array
        group_key: Meta key to group by

    Returns:
        Dictionary mapping group values to metrics
    """
    predictions = np.asarray(predictions).flatten()

    # Bucket (truth, prediction) pairs by each sample's group value.
    buckets = defaultdict(lambda: {'y_true': [], 'y_pred': []})
    for idx, sample in enumerate(samples):
        label = sample.meta.get(group_key, 'unknown')
        if hasattr(sample.y, 'numpy'):
            truth = sample.y.numpy().item()
        else:
            truth = float(sample.y)
        # Pad with 0 when predictions are shorter than samples.
        guess = predictions[idx] if idx < len(predictions) else 0

        buckets[label]['y_true'].append(truth)
        buckets[label]['y_pred'].append(guess)

    return {
        str(label): compute_metrics(np.array(vals['y_true']), np.array(vals['y_pred']))
        for label, vals in buckets.items()
    }

print_grouped_metrics

print_grouped_metrics(grouped_metrics: Dict[str, Dict[str, float]], group_name: str = 'Group') -> None

Print metrics for each group.

Parameters:

Name Type Description Default
grouped_metrics Dict[str, Dict[str, float]]

Dictionary of group -> metrics

required
group_name str

Name for the grouping variable

'Group'
Source code in src/training/metrics.py
def print_grouped_metrics(grouped_metrics: Dict[str, Dict[str, float]], 
                           group_name: str = "Group") -> None:
    """Print RMSE/MAE/R² for every group, sorted by group key.

    Args:
        grouped_metrics: Dictionary of group -> metrics
        group_name: Name for the grouping variable
    """
    for key in sorted(grouped_metrics):
        vals = grouped_metrics[key]
        print(f"\n{group_name} = {key}:")
        print(f"  RMSE: {vals['rmse']:.5f}, MAE: {vals['mae']:.5f}, R²: {vals['r2']:.4f}")

print_metrics

print_metrics(metrics: Dict[str, float], prefix: str = '') -> None

Print metrics in a formatted way.

Parameters:

Name Type Description Default
metrics Dict[str, float]

Metrics dictionary

required
prefix str

Optional prefix for lines

''
Source code in src/training/metrics.py
def print_metrics(metrics: Dict[str, float], prefix: str = "") -> None:
    """Print RMSE/MAE/MAPE/R² in a fixed, readable layout.

    Args:
        metrics: Metrics dictionary
        prefix: Optional prefix for lines
    """
    if prefix:
        print(f"\n=== {prefix} ===")

    lines = [
        f"  RMSE: {metrics['rmse']:.5f}",
        f"  MAE:  {metrics['mae']:.5f}",
        f"  MAPE: {metrics['mape']:.2f}%",
        f"  R²:   {metrics['r2']:.4f}",
    ]
    for line in lines:
        print(line)

losses

Physics-informed losses for battery degradation models.

This module provides a registry-based loss function system that allows configuration-driven selection of loss functions.

Example usage

from src.training.losses import LossRegistry
loss = LossRegistry.get('physics_informed', monotonicity_weight=0.1)
computed_loss = loss(pred, target)

Classes

BaseLoss

Bases: Module, ABC

Abstract base class for all loss functions.

All loss functions must: 1. Inherit from BaseLoss 2. Implement forward(pred, target, t=None) 3. Be registered with @LossRegistry.register()

Example

@LossRegistry.register("custom_loss") ... class CustomLoss(BaseLoss): ... def forward(self, pred, target, t=None): ... return torch.mean((pred - target) ** 2)

Functions
forward abstractmethod
forward(pred: torch.Tensor, target: torch.Tensor, t: Optional[torch.Tensor] = None) -> torch.Tensor

Compute loss.

Parameters:

Name Type Description Default
pred Tensor

Predictions

required
target Tensor

Ground truth

required
t Optional[Tensor]

Optional time values for sequence data

None

Returns:

Type Description
Tensor

Loss value

Source code in src/training/losses.py
@abstractmethod
def forward(self, 
            pred: torch.Tensor, 
            target: torch.Tensor,
            t: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Compute loss.

    Args:
        pred: Predictions
        target: Ground truth
        t: Optional time values for sequence data

    Returns:
        Loss value
    """
    # Abstract hook: concrete losses registered via @LossRegistry.register
    # supply the actual computation.
    pass

HuberLoss

HuberLoss(delta: float = 1.0, reduction: str = 'mean')

Bases: BaseLoss

Huber loss (less sensitive to outliers than MSE).

Uses L2 loss for small errors and L1 loss for large errors, providing robustness to outliers while maintaining smooth gradients.

Example usage

loss_fn = HuberLoss(delta=1.0) loss = loss_fn(pred, target)

Initialize Huber loss.

Parameters:

Name Type Description Default
delta float

Threshold for switching between L1 and L2

1.0
reduction str

Reduction method

'mean'
Source code in src/training/losses.py
def __init__(self, delta: float = 1.0, reduction: str = 'mean'):
    """Set up Huber loss parameters.

    Args:
        delta: Threshold for switching between L1 and L2
        reduction: Reduction method
    """
    super().__init__()
    # Errors below delta are penalized quadratically, above it linearly.
    self.reduction = reduction
    self.delta = delta
Functions
forward
forward(pred: torch.Tensor, target: torch.Tensor, t: Optional[torch.Tensor] = None) -> torch.Tensor

Compute Huber loss.

Source code in src/training/losses.py
def forward(self, 
            pred: torch.Tensor, 
            target: torch.Tensor,
            t: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Compute the Huber loss: quadratic below delta, linear above."""
    abs_err = torch.abs(pred - target)

    # Split each error into a capped quadratic part and a linear remainder;
    # 0.5*q^2 + delta*lin equals the usual piecewise Huber definition.
    quad = abs_err.clamp(max=self.delta)
    lin = abs_err - quad
    loss = 0.5 * quad ** 2 + self.delta * lin

    if self.reduction == 'sum':
        return loss.sum()
    if self.reduction == 'mean':
        return loss.mean()
    return loss

LossRegistry

Registry for loss function classes.

Example usage

@LossRegistry.register("mse") ... class MSELoss(BaseLoss): ... pass

loss = LossRegistry.get("mse", reduction='mean') print(LossRegistry.list_available()) ['mse', 'physics_informed', 'huber', 'mape']

Functions
get classmethod
get(name: str, **kwargs) -> BaseLoss

Get a loss function instance by name.

Parameters:

Name Type Description Default
name str

Registry name of the loss function

required
**kwargs

Arguments to pass to loss constructor

{}

Returns:

Type Description
BaseLoss

Loss function instance

Raises:

Type Description
ValueError

If loss name is not found

Source code in src/training/losses.py
@classmethod
def get(cls, name: str, **kwargs) -> BaseLoss:
    """Look up a registered loss by name and instantiate it.

    Args:
        name: Registry name of the loss function
        **kwargs: Arguments to pass to loss constructor

    Returns:
        Loss function instance

    Raises:
        ValueError: If loss name is not found
    """
    loss_cls = cls._losses.get(name)
    if loss_cls is None:
        available = list(cls._losses.keys())
        raise ValueError(f"Unknown loss: {name}. Available: {available}")
    return loss_cls(**kwargs)
get_class classmethod
get_class(name: str) -> Optional[Type[BaseLoss]]

Get loss class by name (without instantiating).

Parameters:

Name Type Description Default
name str

Registry name of the loss function

required

Returns:

Type Description
Optional[Type[BaseLoss]]

Loss class or None if not found

Source code in src/training/losses.py
@classmethod
def get_class(cls, name: str) -> Optional[Type[BaseLoss]]:
    """Fetch the registered class for `name`, or None when unregistered.

    Args:
        name: Registry name of the loss function

    Returns:
        Loss class or None if not found
    """
    if name in cls._losses:
        return cls._losses[name]
    return None
list_available classmethod
list_available() -> list

List all registered loss function names.

Returns:

Type Description
list

List of loss function names

Source code in src/training/losses.py
@classmethod
def list_available(cls) -> list:
    """Return the registry names of every registered loss.

    Returns:
        List of loss function names
    """
    # Iterating the dict yields its keys in insertion order.
    return [loss_name for loss_name in cls._losses]
register classmethod
register(name: str)

Decorator to register a loss function class.

Parameters:

Name Type Description Default
name str

Registry name for the loss function

required

Returns:

Type Description

Decorator function

Source code in src/training/losses.py
@classmethod
def register(cls, name: str):
    """Class decorator that records a loss class under `name`.

    Args:
        name: Registry name for the loss function

    Returns:
        Decorator function
    """
    def decorator(loss_class: Type[BaseLoss]):
        # Stamp the registry name onto the class for introspection,
        # then record it so LossRegistry.get(name) can find it.
        loss_class.name = name
        cls._losses[name] = loss_class
        return loss_class
    return decorator

MAELoss

MAELoss(reduction: str = 'mean')

Bases: BaseLoss

Mean Absolute Error loss (L1 loss).

More robust to outliers than MSE, but has non-smooth gradients at zero.

Example usage

loss_fn = MAELoss() loss = loss_fn(pred, target)

Initialize MAE loss.

Parameters:

Name Type Description Default
reduction str

Reduction method ('mean', 'sum', 'none')

'mean'
Source code in src/training/losses.py
def __init__(self, reduction: str = 'mean'):
    """Set up MAE loss.

    Args:
        reduction: Reduction method ('mean', 'sum', 'none')
    """
    super().__init__()
    self.reduction = reduction
    # Delegate the actual computation to torch's built-in L1 loss.
    self.l1 = nn.L1Loss(reduction=reduction)
Functions
forward
forward(pred: torch.Tensor, target: torch.Tensor, t: Optional[torch.Tensor] = None) -> torch.Tensor

Compute MAE loss.

Source code in src/training/losses.py
def forward(self, 
            pred: torch.Tensor, 
            target: torch.Tensor,
            t: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Compute MAE via the wrapped ``nn.L1Loss``; `t` is ignored."""
    absolute_error = self.l1(pred, target)
    return absolute_error

MSELoss

MSELoss(reduction: str = 'mean')

Bases: BaseLoss

Standard Mean Squared Error loss.

Example usage

loss_fn = MSELoss() loss = loss_fn(pred, target)

Initialize MSE loss.

Parameters:

Name Type Description Default
reduction str

Reduction method ('mean', 'sum', 'none')

'mean'
Source code in src/training/losses.py
def __init__(self, reduction: str = 'mean'):
    """Set up MSE loss.

    Args:
        reduction: Reduction method ('mean', 'sum', 'none')
    """
    super().__init__()
    self.reduction = reduction
    # Delegate the actual computation to torch's built-in MSE loss.
    self.mse = nn.MSELoss(reduction=reduction)
Functions
forward
forward(pred: torch.Tensor, target: torch.Tensor, t: Optional[torch.Tensor] = None) -> torch.Tensor

Compute MSE loss.

Source code in src/training/losses.py
def forward(self, 
            pred: torch.Tensor, 
            target: torch.Tensor,
            t: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Compute MSE via the wrapped ``nn.MSELoss``; `t` is ignored."""
    squared_error = self.mse(pred, target)
    return squared_error

PercentageLoss

PercentageLoss(epsilon: float = 1e-08, reduction: str = 'mean')

Bases: BaseLoss

Mean Absolute Percentage Error loss.

Useful when relative errors are more important than absolute errors. Returns loss as a percentage (0-100 scale).

Example usage

loss_fn = PercentageLoss() loss = loss_fn(pred, target) # Returns MAPE as percentage

Initialize MAPE loss.

Parameters:

Name Type Description Default
epsilon float

Small value to prevent division by zero

1e-08
reduction str

Reduction method

'mean'
Source code in src/training/losses.py
def __init__(self, epsilon: float = 1e-8, reduction: str = 'mean'):
    """Set up MAPE loss.

    Args:
        epsilon: Small value to prevent division by zero
        reduction: Reduction method
    """
    super().__init__()
    # epsilon is added to |target| so zero targets don't divide by zero.
    self.reduction = reduction
    self.epsilon = epsilon
Functions
forward
forward(pred: torch.Tensor, target: torch.Tensor, t: Optional[torch.Tensor] = None) -> torch.Tensor

Compute MAPE loss.

Source code in src/training/losses.py
def forward(self, 
            pred: torch.Tensor, 
            target: torch.Tensor,
            t: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Compute MAPE as a percentage on a 0-100 scale; `t` is ignored."""
    rel_err = torch.abs((target - pred) / (target.abs() + self.epsilon))

    if self.reduction == 'mean':
        reduced = rel_err.mean()
    elif self.reduction == 'sum':
        reduced = rel_err.sum()
    else:
        reduced = rel_err
    # Scale the relative error to a percentage.
    return reduced * 100

PhysicsInformedLoss

PhysicsInformedLoss(monotonicity_weight: float = 0.0, smoothness_weight: float = 0.0, reduction: str = 'mean')

Bases: BaseLoss

MSE loss with optional physics-based regularization.

Regularization terms: - Monotonicity: SOH should generally decrease over time - Smoothness: Predictions should be smooth - Arrhenius consistency: Temperature effects should follow Arrhenius

Example usage

loss_fn = PhysicsInformedLoss(monotonicity_weight=0.1) loss = loss_fn(pred, target, t=times)

Initialize the loss.

Parameters:

Name Type Description Default
monotonicity_weight float

Weight for monotonicity regularization

0.0
smoothness_weight float

Weight for smoothness regularization

0.0
reduction str

Reduction method ('mean', 'sum', 'none')

'mean'
Source code in src/training/losses.py
def __init__(self,
             monotonicity_weight: float = 0.0,
             smoothness_weight: float = 0.0,
             reduction: str = 'mean'):
    """Configure the physics-informed loss.

    Args:
        monotonicity_weight: Weight for monotonicity regularization
        smoothness_weight: Weight for smoothness regularization
        reduction: Reduction method ('mean', 'sum', 'none')
    """
    super().__init__()
    # Base data-fit term; regularizers are added on top in forward().
    self.mse = nn.MSELoss(reduction=reduction)
    self.reduction = reduction
    self.smoothness_weight = smoothness_weight
    self.monotonicity_weight = monotonicity_weight
Functions
forward
forward(pred: torch.Tensor, target: torch.Tensor, t: Optional[torch.Tensor] = None) -> torch.Tensor

Compute loss.

Parameters:

Name Type Description Default
pred Tensor

Predictions

required
target Tensor

Ground truth

required
t Optional[Tensor]

Optional time values for sequence data

None

Returns:

Type Description
Tensor

Loss value

Source code in src/training/losses.py
def forward(self, 
            pred: torch.Tensor, 
            target: torch.Tensor,
            t: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Compute MSE plus optional physics regularization terms.

    Regularizers apply along dim 1, which is assumed to be the sequence
    (time) axis for 2D+ predictions — TODO confirm against callers.

    Args:
        pred: Predictions
        target: Ground truth
        t: Optional time values for sequence data (currently unused)

    Returns:
        Loss value
    """
    # Base MSE loss
    loss = self.mse(pred, target)

    # Monotonicity regularization: penalize increases in SOH over time.
    # BUG FIX: require at least 2 steps along dim 1 — for a length-1
    # sequence the diff tensor is empty and `.mean()` of an empty tensor
    # is NaN, which poisoned the whole loss.
    if self.monotonicity_weight > 0 and pred.dim() >= 2 and pred.size(1) >= 2:
        diff = pred[:, 1:] - pred[:, :-1]
        monotonicity_penalty = torch.relu(diff).mean()
        loss = loss + self.monotonicity_weight * monotonicity_penalty

    # Smoothness regularization: the discrete second derivative should be
    # small. Needs at least 3 steps for the same empty-mean reason.
    if self.smoothness_weight > 0 and pred.dim() >= 2 and pred.size(1) >= 3:
        diff2 = pred[:, 2:] - 2 * pred[:, 1:-1] + pred[:, :-2]
        smoothness_penalty = (diff2 ** 2).mean()
        loss = loss + self.smoothness_weight * smoothness_penalty

    return loss

callbacks

Training callbacks for monitoring and checkpointing.

Classes

Callback

Base callback class.

Functions
on_epoch_end
on_epoch_end(epoch: int, logs: Dict[str, Any] = None) -> None

Called at end of each epoch.

Source code in src/training/callbacks.py
def on_epoch_end(self, epoch: int, logs: Optional[Dict[str, Any]] = None) -> None:
    """Called at end of each epoch."""
    # No-op hook; subclasses override to react to per-epoch metrics in `logs`.
    pass
on_epoch_start
on_epoch_start(epoch: int, logs: Dict[str, Any] = None) -> None

Called at start of each epoch.

Source code in src/training/callbacks.py
def on_epoch_start(self, epoch: int, logs: Optional[Dict[str, Any]] = None) -> None:
    """Called at start of each epoch."""
    # No-op hook; subclasses override as needed.
    pass
on_train_end
on_train_end(logs: Dict[str, Any] = None) -> None

Called at end of training.

Source code in src/training/callbacks.py
def on_train_end(self, logs: Optional[Dict[str, Any]] = None) -> None:
    """Called at end of training."""
    # No-op hook; subclasses override as needed.
    pass
on_train_start
on_train_start(logs: Dict[str, Any] = None) -> None

Called at start of training.

Source code in src/training/callbacks.py
def on_train_start(self, logs: Optional[Dict[str, Any]] = None) -> None:
    """Called at start of training."""
    # No-op hook; subclasses override as needed.
    pass

EarlyStoppingCallback

EarlyStoppingCallback(monitor: str = 'val_loss', patience: int = 20, min_delta: float = 0.0, mode: str = 'min')

Bases: Callback

Stop training when a monitored metric has stopped improving.

Initialize early stopping.

Parameters:

Name Type Description Default
monitor str

Metric to monitor

'val_loss'
patience int

Number of epochs to wait

20
min_delta float

Minimum change to qualify as improvement

0.0
mode str

'min' or 'max'

'min'
Source code in src/training/callbacks.py
def __init__(self,
             monitor: str = 'val_loss',
             patience: int = 20,
             min_delta: float = 0.0,
             mode: str = 'min'):
    """Configure early stopping.

    Args:
        monitor: Metric to monitor
        patience: Number of epochs to wait
        min_delta: Minimum change to qualify as improvement
        mode: 'min' or 'max'
    """
    self.monitor = monitor
    self.patience = patience
    self.min_delta = min_delta
    self.mode = mode

    # Seed `best` with the worst possible value for the chosen direction.
    self.best = float('inf') if mode == 'min' else float('-inf')
    self.counter = 0
    self.should_stop = False
Functions
on_epoch_end
on_epoch_end(epoch: int, logs: Dict[str, Any] = None) -> None

Check if training should stop.

Source code in src/training/callbacks.py
def on_epoch_end(self, epoch: int, logs: Dict[str, Any] = None) -> None:
    """Update best/counter and set the stop flag when patience runs out."""
    if logs is None or self.monitor not in logs:
        return

    current = logs[self.monitor]

    # An epoch "improves" only if it beats best by more than min_delta.
    if self.mode == 'min':
        improved = current < self.best - self.min_delta
    else:
        improved = current > self.best + self.min_delta

    if not improved:
        self.counter += 1
        if self.counter >= self.patience:
            self.should_stop = True
            logger.info(f"Early stopping triggered at epoch {epoch}")
        return

    self.best = current
    self.counter = 0

LRSchedulerCallback

LRSchedulerCallback(scheduler, step_on: str = 'epoch')

Bases: Callback

Learning rate scheduler callback.

Initialize scheduler callback.

Parameters:

Name Type Description Default
scheduler

PyTorch scheduler

required
step_on str

When to step ('epoch' or 'batch')

'epoch'
Source code in src/training/callbacks.py
def __init__(self, scheduler, step_on: str = 'epoch'):
    """Wrap a PyTorch scheduler so the callback loop drives its stepping.

    Args:
        scheduler: PyTorch scheduler
        step_on: When to step ('epoch' or 'batch')
    """
    self.step_on = step_on
    self.scheduler = scheduler
Functions
on_epoch_end
on_epoch_end(epoch: int, logs: Dict[str, Any] = None) -> None

Step the scheduler.

Source code in src/training/callbacks.py
def on_epoch_end(self, epoch: int, logs: Dict[str, Any] = None) -> None:
    """Advance the wrapped scheduler once per epoch (when configured so)."""
    if self.step_on != 'epoch':
        return
    self.scheduler.step()

ModelCheckpointCallback

ModelCheckpointCallback(save_path: Path, monitor: str = 'val_loss', mode: str = 'min', save_best_only: bool = True, save_fn: Optional[Callable] = None)

Bases: Callback

Save model checkpoints based on monitored metric.

Initialize checkpointing.

Parameters:

Name Type Description Default
save_path Path

Path to save checkpoint

required
monitor str

Metric to monitor

'val_loss'
mode str

'min' or 'max'

'min'
save_best_only bool

Only save when metric improves

True
save_fn Optional[Callable]

Custom save function

None
Source code in src/training/callbacks.py
def __init__(self,
             save_path: Path,
             monitor: str = 'val_loss',
             mode: str = 'min',
             save_best_only: bool = True,
             save_fn: Optional[Callable] = None):
    """Configure checkpointing.

    Args:
        save_path: Path to save checkpoint
        monitor: Metric to monitor
        mode: 'min' or 'max'
        save_best_only: Only save when metric improves
        save_fn: Custom save function
    """
    # Normalize to a Path so string inputs also work.
    self.save_path = Path(save_path)
    self.monitor = monitor
    self.mode = mode
    self.save_best_only = save_best_only
    self.save_fn = save_fn

    # Seed with the worst possible value for the chosen direction.
    self.best = float('inf') if mode == 'min' else float('-inf')
Functions
on_epoch_end
on_epoch_end(epoch: int, logs: Dict[str, Any] = None) -> None

Save checkpoint if metric improved.

Source code in src/training/callbacks.py
def on_epoch_end(self, epoch: int, logs: Dict[str, Any] = None) -> None:
    """Save a checkpoint when the metric improves (or every epoch when
    ``save_best_only`` is False)."""
    if logs is None or self.monitor not in logs:
        return

    current = logs[self.monitor]

    if self.mode == 'min':
        improved = current < self.best
    else:
        improved = current > self.best

    # BUG FIX: track the best value only on genuine improvement.
    # Previously `best` was overwritten on every save when
    # save_best_only=False, so `best` drifted to whatever the last
    # epoch produced instead of the actual best.
    if improved:
        self.best = current

    if improved or not self.save_best_only:
        if self.save_fn:
            self.save_fn(self.save_path)
        logger.info(f"Checkpoint saved: {self.monitor}={current:.5f}")

ProgressCallback

ProgressCallback(print_every: int = 10)

Bases: Callback

Print training progress.

Initialize progress callback.

Parameters:

Name Type Description Default
print_every int

Print every N epochs

10
Source code in src/training/callbacks.py
def __init__(self, print_every: int = 10):
    """Configure the progress-reporting cadence.

    Args:
        print_every: Print every N epochs
    """
    # Interval (in epochs) between log lines emitted by on_epoch_end.
    self.print_every = print_every
Functions
on_epoch_end
on_epoch_end(epoch: int, logs: Dict[str, Any] = None) -> None

Print progress.

Source code in src/training/callbacks.py
def on_epoch_end(self, epoch: int, logs: Dict[str, Any] = None) -> None:
    """Log all metrics every `print_every` epochs."""
    if logs and (epoch + 1) % self.print_every == 0:
        formatted = ", ".join(f"{key}: {value:.5f}" for key, value in logs.items())
        logger.info(f"Epoch {epoch + 1}: {formatted}")