Training API Reference¶
The training module provides the core training infrastructure, including the Trainer class (handling mixed precision and early stopping), a registry for loss functions (Huber, MSE, Physics-Informed), and evaluation metrics tailored for SOH prediction.
trainer ¶
Training loop with AMP, gradient clipping, early stopping.
Classes¶
Trainer ¶
Trainer(model: nn.Module, config: Dict[str, Any], tracker: Optional[BaseTracker] = None, device: str = 'auto', loss_config: Optional[Dict[str, Any]] = None, verbose: bool = False)
Training loop with AMP, gradient clipping, early stopping.
Features:

- Automatic mixed precision (AMP) for faster training on GPU
- Gradient clipping for stability
- CosineAnnealing learning rate schedule
- Early stopping with patience
- Automatic Sample → DataLoader conversion
Example usage
```python
trainer = Trainer(model, config, tracker)
history = trainer.fit(train_samples, val_samples)
```
Initialize the trainer.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| model | Module | PyTorch model | required |
| config | Dict[str, Any] | Training configuration | required |
| tracker | Optional[BaseTracker] | Experiment tracker | None |
| device | str | Device to use ('auto', 'cuda', 'cpu') | 'auto' |
| loss_config | Optional[Dict[str, Any]] | Loss function configuration dict with 'name' key and loss-specific parameters. If None, defaults to MSE. | None |
| verbose | bool | If True, print training progress every epoch | False |
Source code in src/training/trainer.py
Functions¶
fit ¶
fit(train_samples: List[Sample], val_samples: List[Sample], epochs: Optional[int] = None, patience: Optional[int] = None) -> Dict[str, List[float]]
Train the model.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| train_samples | List[Sample] | Training samples | required |
| val_samples | List[Sample] | Validation samples | required |
| epochs | Optional[int] | Number of epochs (default from config) | None |
| patience | Optional[int] | Early stopping patience (default from config) | None |
Returns:
| Type | Description |
|---|---|
| Dict[str, List[float]] | Training history with train_loss and val_loss |
Source code in src/training/trainer.py
load ¶
Load model and optimizer state.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| path | Path | Path to checkpoint | required |
Source code in src/training/trainer.py
predict ¶
Get predictions for samples.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| samples | List[Sample] | List of samples | required |
Returns:
| Type | Description |
|---|---|
| ndarray | Numpy array of predictions |
Source code in src/training/trainer.py
save ¶
Save model and optimizer state.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| path | Path | Path to save checkpoint | required |
Source code in src/training/trainer.py
metrics ¶
Evaluation metrics for battery degradation prediction.
Classes¶
Functions¶
compute_metrics ¶
Compute regression metrics.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| y_true | ndarray | Ground truth values | required |
| y_pred | ndarray | Predicted values | required |
Returns:
| Type | Description |
|---|---|
| Dict[str, float] | Dictionary with RMSE, MAE, MAPE, R² metrics |
Source code in src/training/metrics.py
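The four metrics returned above can be spelled out concretely. The following is an illustrative pure-Python re-implementation (the real compute_metrics operates on numpy arrays; the function name and list-based signature here are assumptions for demonstration only):

```python
import math

def compute_metrics_sketch(y_true, y_pred):
    """Illustrative RMSE / MAE / MAPE / R^2, matching the metrics
    listed above. Not the project's implementation."""
    n = len(y_true)
    errors = [t - p for t, p in zip(y_true, y_pred)]
    rmse = math.sqrt(sum(e * e for e in errors) / n)
    mae = sum(abs(e) for e in errors) / n
    # MAPE on the 0-100 scale, relative to the true values.
    mape = 100.0 * sum(abs(e / t) for e, t in zip(errors, y_true)) / n
    mean_true = sum(y_true) / n
    ss_res = sum(e * e for e in errors)
    ss_tot = sum((t - mean_true) ** 2 for t in y_true)
    r2 = 1.0 - ss_res / ss_tot
    return {"RMSE": rmse, "MAE": mae, "MAPE": mape, "R2": r2}
```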
evaluate_by_group ¶
evaluate_by_group(samples: List[Sample], predictions: np.ndarray, group_key: str = 'temperature_C') -> Dict[str, Dict[str, float]]
Compute metrics per group (e.g., per temperature).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| samples | List[Sample] | List of Sample objects | required |
| predictions | ndarray | Predictions array | required |
| group_key | str | Meta key to group by | 'temperature_C' |
Returns:
| Type | Description |
|---|---|
| Dict[str, Dict[str, float]] | Dictionary mapping group values to metrics |
Source code in src/training/metrics.py
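The grouping step can be sketched in plain Python. This sketch assumes samples carry a meta dict (based on the "Meta key to group by" description above) and computes only MAE per group; the real function uses Sample objects and the full metric set:

```python
def evaluate_by_group_sketch(samples, predictions, group_key="temperature_C"):
    """Illustrative per-group evaluation. Samples are modeled here as
    dicts with 'meta' and 'target' keys (an assumption for the sketch)."""
    groups = {}
    for sample, pred in zip(samples, predictions):
        key = sample["meta"][group_key]
        groups.setdefault(key, []).append((sample["target"], pred))
    result = {}
    for key, pairs in groups.items():
        n = len(pairs)
        mae = sum(abs(t - p) for t, p in pairs) / n
        result[str(key)] = {"MAE": mae, "count": n}
    return result
```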
print_grouped_metrics ¶
print_grouped_metrics(grouped_metrics: Dict[str, Dict[str, float]], group_name: str = 'Group') -> None
Print metrics for each group.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| grouped_metrics | Dict[str, Dict[str, float]] | Dictionary of group -> metrics | required |
| group_name | str | Name for the grouping variable | 'Group' |
Source code in src/training/metrics.py
print_metrics ¶
Print metrics in a formatted way.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| metrics | Dict[str, float] | Metrics dictionary | required |
| prefix | str | Optional prefix for lines | '' |
Source code in src/training/metrics.py
losses ¶
Physics-informed losses for battery degradation models.
This module provides a registry-based loss function system that allows configuration-driven selection of loss functions.
Example usage
```python
from src.training.losses import LossRegistry

loss = LossRegistry.get('physics_informed', monotonicity_weight=0.1)
computed_loss = loss(pred, target)
```
Classes¶
BaseLoss ¶
Bases: Module, ABC
Abstract base class for all loss functions.
All loss functions must:

1. Inherit from BaseLoss
2. Implement forward(pred, target, t=None)
3. Be registered with @LossRegistry.register()
Example
```python
@LossRegistry.register("custom_loss")
class CustomLoss(BaseLoss):
    def forward(self, pred, target, t=None):
        return torch.mean((pred - target) ** 2)
```
Functions¶
forward abstractmethod ¶
Compute loss.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| pred | Tensor | Predictions | required |
| target | Tensor | Ground truth | required |
| t | Optional[Tensor] | Optional time values for sequence data | None |
Returns:
| Type | Description |
|---|---|
| Tensor | Loss value |
Source code in src/training/losses.py
HuberLoss ¶
Bases: BaseLoss
Huber loss (less sensitive to outliers than MSE).
Uses L2 loss for small errors and L1 loss for large errors, providing robustness to outliers while maintaining smooth gradients.
Example usage
```python
loss_fn = HuberLoss(delta=1.0)
loss = loss_fn(pred, target)
```
Initialize Huber loss.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| delta | float | Threshold for switching between L1 and L2 | 1.0 |
| reduction | str | Reduction method | 'mean' |
Source code in src/training/losses.py
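The L2/L1 switch that makes Huber loss robust can be written out element-wise. This is an illustrative pure-Python version of the standard Huber formula with 'mean' reduction, not the project's tensor-based HuberLoss:

```python
def huber_sketch(pred, target, delta=1.0):
    """Standard Huber loss: quadratic for |error| <= delta,
    linear (with matched slope) beyond it. Illustrative only."""
    losses = []
    for p, t in zip(pred, target):
        e = abs(p - t)
        if e <= delta:
            losses.append(0.5 * e * e)          # L2 region: smooth gradients
        else:
            losses.append(delta * (e - 0.5 * delta))  # L1 region: outlier-robust
    return sum(losses) / len(losses)  # 'mean' reduction
```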
Functions¶
forward ¶
Compute Huber loss.
Source code in src/training/losses.py
LossRegistry ¶
Registry for loss function classes.
Example usage
```python
@LossRegistry.register("mse")
class MSELoss(BaseLoss):
    pass

loss = LossRegistry.get("mse", reduction='mean')
print(LossRegistry.list_available())
# ['mse', 'physics_informed', 'huber', 'mape']
```
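The register/get pattern behind LossRegistry can be sketched independently of the project. MiniRegistry below is a minimal, self-contained stand-in (its names are hypothetical) showing how a decorator-based registry maps string names to classes and instantiates them on demand:

```python
class MiniRegistry:
    """Minimal sketch of a decorator-based class registry,
    mirroring the register/get/list pattern described above."""
    _registry = {}

    @classmethod
    def register(cls, name):
        def decorator(klass):
            cls._registry[name] = klass  # map name -> class, not instance
            return klass
        return decorator

    @classmethod
    def get(cls, name, **kwargs):
        if name not in cls._registry:
            raise ValueError(f"Unknown loss: {name!r}")
        return cls._registry[name](**kwargs)  # instantiate with kwargs

@MiniRegistry.register("mse")
class MSE:
    def __init__(self, reduction="mean"):
        self.reduction = reduction
```

Registering classes rather than instances means constructor arguments can vary per call to get, which is what makes configuration-driven loss selection work.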
Functions¶
get classmethod ¶
Get a loss function instance by name.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| name | str | Registry name of the loss function | required |
| **kwargs | | Arguments to pass to loss constructor | {} |
Returns:
| Type | Description |
|---|---|
| BaseLoss | Loss function instance |
Raises:
| Type | Description |
|---|---|
| ValueError | If loss name is not found |
Source code in src/training/losses.py
get_class classmethod ¶
Get loss class by name (without instantiating).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| name | str | Registry name of the loss function | required |
Returns:
| Type | Description |
|---|---|
| Optional[Type[BaseLoss]] | Loss class or None if not found |
Source code in src/training/losses.py
list_available classmethod ¶
List all registered loss function names.
Returns:
| Type | Description |
|---|---|
| list | List of loss function names |
register classmethod ¶
Decorator to register a loss function class.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| name | str | Registry name for the loss function | required |
Returns:
| Type | Description |
|---|---|
| | Decorator function |
Source code in src/training/losses.py
MAELoss ¶
Bases: BaseLoss
Mean Absolute Error loss (L1 loss).
More robust to outliers than MSE, but has non-smooth gradients at zero.
Example usage
```python
loss_fn = MAELoss()
loss = loss_fn(pred, target)
```
Initialize MAE loss.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| reduction | str | Reduction method ('mean', 'sum', 'none') | 'mean' |
Source code in src/training/losses.py
MSELoss ¶
PercentageLoss ¶
Bases: BaseLoss
Mean Absolute Percentage Error loss.
Useful when relative errors are more important than absolute errors. Returns loss as a percentage (0-100 scale).
Example usage
```python
loss_fn = PercentageLoss()
loss = loss_fn(pred, target)  # Returns MAPE as percentage
```
Initialize MAPE loss.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| epsilon | float | Small value to prevent division by zero | 1e-08 |
| reduction | str | Reduction method | 'mean' |
Source code in src/training/losses.py
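The role of epsilon can be shown concretely. This is an illustrative pure-Python MAPE on the 0-100 scale, with epsilon added to the denominator so a zero target cannot cause a division error (the function name and list-based signature are assumptions; the real PercentageLoss works on tensors):

```python
def mape_sketch(pred, target, epsilon=1e-8):
    """Illustrative MAPE (0-100 scale) with a division-by-zero guard."""
    n = len(pred)
    return 100.0 * sum(
        abs(p - t) / (abs(t) + epsilon) for p, t in zip(pred, target)
    ) / n
```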
Functions¶
forward ¶
Compute MAPE loss.
Source code in src/training/losses.py
PhysicsInformedLoss ¶
PhysicsInformedLoss(monotonicity_weight: float = 0.0, smoothness_weight: float = 0.0, reduction: str = 'mean')
Bases: BaseLoss
MSE loss with optional physics-based regularization.
Regularization terms:

- Monotonicity: SOH should generally decrease over time
- Smoothness: predictions should vary smoothly between consecutive time steps
- Arrhenius consistency: temperature effects should follow Arrhenius kinetics
Example usage
```python
loss_fn = PhysicsInformedLoss(monotonicity_weight=0.1)
loss = loss_fn(pred, target, t=times)
```
Initialize the loss.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| monotonicity_weight | float | Weight for monotonicity regularization | 0.0 |
| smoothness_weight | float | Weight for smoothness regularization | 0.0 |
| reduction | str | Reduction method ('mean', 'sum', 'none') | 'mean' |
Source code in src/training/losses.py
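One plausible way to realize the monotonicity and smoothness terms is with finite differences over a predicted SOH sequence. The sketch below is an assumption about the form of these penalties (penalizing increases for monotonicity, second differences for smoothness), not the verified implementation, which operates on tensors and weights these terms into the MSE:

```python
def physics_penalties_sketch(preds):
    """Illustrative finite-difference penalties for a predicted
    SOH sequence ordered in time. Hypothetical formulation."""
    diffs = [b - a for a, b in zip(preds, preds[1:])]
    # Monotonicity: SOH should decrease, so penalize positive first differences.
    monotonicity = sum(max(d, 0.0) ** 2 for d in diffs)
    # Smoothness: penalize curvature via squared second differences.
    second = [b - a for a, b in zip(diffs, diffs[1:])]
    smoothness = sum(d * d for d in second)
    return monotonicity, smoothness
```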
Functions¶
forward ¶
Compute loss.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| pred | Tensor | Predictions | required |
| target | Tensor | Ground truth | required |
| t | Optional[Tensor] | Optional time values for sequence data | None |
Returns:
| Type | Description |
|---|---|
| Tensor | Loss value |
Source code in src/training/losses.py
callbacks ¶
Training callbacks for monitoring and checkpointing.
Classes¶
EarlyStoppingCallback ¶
EarlyStoppingCallback(monitor: str = 'val_loss', patience: int = 20, min_delta: float = 0.0, mode: str = 'min')
Bases: Callback
Stop training when a monitored metric has stopped improving.
Initialize early stopping.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| monitor | str | Metric to monitor | 'val_loss' |
| patience | int | Number of epochs without improvement to wait before stopping | 20 |
| min_delta | float | Minimum change to qualify as improvement | 0.0 |
| mode | str | 'min' or 'max' | 'min' |
Source code in src/training/callbacks.py
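The patience logic can be captured in a few lines. The sketch below is a minimal, self-contained version for mode='min' (its class and method names are hypothetical): the wait counter resets whenever the metric improves by at least min_delta, and training stops once it reaches patience:

```python
class EarlyStopSketch:
    """Illustrative early-stopping counter (mode='min')."""
    def __init__(self, patience=20, min_delta=0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best = float("inf")
        self.wait = 0

    def update(self, value):
        """Record one epoch's metric; return True when training should stop."""
        if value < self.best - self.min_delta:
            self.best = value   # improvement: reset the counter
            self.wait = 0
        else:
            self.wait += 1      # no improvement this epoch
        return self.wait >= self.patience
```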
Functions¶
on_epoch_end ¶
Check if training should stop.
Source code in src/training/callbacks.py
LRSchedulerCallback ¶
ModelCheckpointCallback ¶
ModelCheckpointCallback(save_path: Path, monitor: str = 'val_loss', mode: str = 'min', save_best_only: bool = True, save_fn: Optional[Callable] = None)
Bases: Callback
Save model checkpoints based on monitored metric.
Initialize checkpointing.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| save_path | Path | Path to save checkpoint | required |
| monitor | str | Metric to monitor | 'val_loss' |
| mode | str | 'min' or 'max' | 'min' |
| save_best_only | bool | Only save when metric improves | True |
| save_fn | Optional[Callable] | Custom save function | None |
Source code in src/training/callbacks.py
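The save_best_only behavior can be sketched in isolation. The class below is a hypothetical stand-in for mode='min' (the real callback writes a PyTorch checkpoint to save_path): the injected save_fn fires only when the monitored metric improves on the best value seen so far:

```python
class BestCheckpointSketch:
    """Illustrative save-best-only checkpointing (mode='min')."""
    def __init__(self, save_fn, save_best_only=True):
        self.save_fn = save_fn
        self.save_best_only = save_best_only
        self.best = float("inf")

    def on_epoch_end(self, value):
        # Save unconditionally, or only when the metric improves.
        if not self.save_best_only or value < self.best:
            self.best = min(self.best, value)
            self.save_fn(value)
```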
Functions¶
on_epoch_end ¶
Save checkpoint if metric improved.