Transmission Error Regression Module
This page documents the Lightning regression wrapper that standardizes normalization, loss computation, metric logging, and optimizer configuration for transmission error (TE) models.
Lightning regression module for normalized TE prediction workflows.
- class scripts.training.transmission_error_regression_module.TransmissionErrorRegressionModule(regression_model, input_feature_dim, target_feature_dim, learning_rate=1.0e-3, weight_decay=1.0e-4, normalization_statistics=None)
LightningModule that wraps TE backbones, normalization, and metrics.
- Parameters:
regression_model (nn.Module)
input_feature_dim (int)
target_feature_dim (int)
learning_rate (float)
weight_decay (float)
normalization_statistics (NormalizationStatistics | None)
- __init__(regression_model, input_feature_dim, target_feature_dim, learning_rate=1.0e-3, weight_decay=1.0e-4, normalization_statistics=None)
Initialize the TE regression LightningModule.
- Parameters:
regression_model (Module) – Backbone model operating on normalized inputs.
input_feature_dim (int) – Number of model input features.
target_feature_dim (int) – Number of regression targets.
learning_rate (float) – AdamW learning rate.
weight_decay (float) – AdamW weight decay.
normalization_statistics (NormalizationStatistics | None) – Optional normalization tensors loaded at construction time.
- Return type:
None
- set_normalization_statistics(normalization_statistics)
Load normalization tensors into the module buffers.
- Parameters:
normalization_statistics (NormalizationStatistics) – Input and target statistics computed from the training split.
- Return type:
None
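The page does not list the fields of `NormalizationStatistics`. The framework-free sketch below assumes it holds a per-feature mean and standard deviation for inputs and for targets, and shows those values being computed from the training split only. The field names `input_mean`, `input_std`, `target_mean`, `target_std` and the helper `compute_normalization_statistics` are hypothetical; plain Python lists stand in for tensors.

```python
from dataclasses import dataclass
from statistics import fmean, pstdev


@dataclass(frozen=True)
class NormalizationStatistics:
    """Hypothetical stand-in: per-feature mean/std for inputs and targets."""

    input_mean: list[float]
    input_std: list[float]
    target_mean: list[float]
    target_std: list[float]


def _column_stats(
    rows: list[list[float]], eps: float = 1e-8
) -> tuple[list[float], list[float]]:
    # Transpose rows into per-feature columns.
    columns = list(zip(*rows))
    means = [fmean(column) for column in columns]
    # Clamp near-zero standard deviations so later divisions stay finite.
    stds = [max(pstdev(column), eps) for column in columns]
    return means, stds


def compute_normalization_statistics(
    train_inputs: list[list[float]], train_targets: list[list[float]]
) -> NormalizationStatistics:
    # Only training-split data is used, so validation/test data cannot leak in.
    input_mean, input_std = _column_stats(train_inputs)
    target_mean, target_std = _column_stats(train_targets)
    return NormalizationStatistics(input_mean, input_std, target_mean, target_std)
```

Clamping the standard deviation to a small `eps` keeps a constant feature from producing a division by zero when inputs are later normalized.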
- normalize_input_tensor(input_tensor)
Normalize model inputs with the registered training statistics.
- Parameters:
input_tensor (Tensor)
- Return type:
Tensor
- normalize_target_tensor(target_tensor)
Normalize regression targets with the registered target statistics.
- Parameters:
target_tensor (Tensor)
- Return type:
Tensor
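The page does not state which normalization scheme is used. Assuming per-feature z-score normalization, the arithmetic can be sketched as follows: `normalize_input_tensor` would apply it with the input statistics, and `normalize_target_tensor` with the target statistics. `normalize_rows` is a hypothetical helper that operates on lists of rows rather than tensors.

```python
def normalize_rows(
    rows: list[list[float]], mean: list[float], std: list[float]
) -> list[list[float]]:
    # Guard against feature-dimension mismatches before computing anything.
    if len(mean) != len(std) or any(len(row) != len(mean) for row in rows):
        raise ValueError("feature dimension mismatch")
    # z-score per feature: subtract the training mean, divide by the training std.
    return [[(v - m) / s for v, m, s in zip(row, mean, std)] for row in rows]
```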
- denormalize_target_tensor(normalized_target_tensor)
Map normalized target predictions back to physical TE units.
- Parameters:
normalized_target_tensor (Tensor)
- Return type:
Tensor
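Under the same z-score assumption, denormalization is the inverse map, y = z * std + mean, applied with the target statistics. A minimal sketch using a hypothetical `denormalize_rows` helper on lists:

```python
def denormalize_rows(
    rows: list[list[float]], mean: list[float], std: list[float]
) -> list[list[float]]:
    # Inverse of z-score normalization: scale by the target std, then shift by the mean.
    return [[z * s + m for z, m, s in zip(row, mean, std)] for row in rows]
```

Because the two maps are inverses (up to floating-point rounding), normalizing a target and then denormalizing it recovers the original physical value.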
- forward(normalized_input_tensor)
Run the backbone on normalized inputs only.
- Parameters:
normalized_input_tensor (Tensor)
- Return type:
Tensor
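Because `forward` accepts only normalized inputs and the backbone's predictions are in normalized space, producing TE values in physical units takes three steps: normalize, run the backbone, denormalize. The sketch below assumes z-score statistics and uses a toy callable in place of a real backbone; `predict_physical` is a hypothetical helper, not part of the module.

```python
from typing import Callable

Row = list[float]


def predict_physical(
    backbone: Callable[[list[Row]], list[Row]],
    raw_inputs: list[Row],
    input_mean: Row,
    input_std: Row,
    target_mean: Row,
    target_std: Row,
) -> list[Row]:
    # 1. Bring raw inputs into the normalized space the backbone was trained on.
    normalized_inputs = [
        [(x - m) / s for x, m, s in zip(row, input_mean, input_std)]
        for row in raw_inputs
    ]
    # 2. The backbone sees only normalized inputs and returns normalized predictions.
    normalized_predictions = backbone(normalized_inputs)
    # 3. Map predictions back to physical TE units with the target statistics.
    return [
        [z * s + m for z, m, s in zip(row, target_mean, target_std)]
        for row in normalized_predictions
    ]
```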
- forward_regression_model(input_tensor, normalized_input_tensor)
Run the backbone while supporting structured auxiliary outputs.
- Parameters:
input_tensor (Tensor) – Raw input tensor, before normalization.
normalized_input_tensor (Tensor) – Normalized model input tensor.
- Returns:
Prediction tensor in normalized space plus any auxiliary structured outputs emitted by the backbone.
- Return type:
tuple[torch.Tensor, dict[str, torch.Tensor]]
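`forward_regression_model` returns a normalized-space prediction together with a dictionary of auxiliary outputs. The page does not say which return shapes a backbone may use, so the sketch below assumes three: a bare prediction, a `(prediction, dict)` tuple, or a dict with a `"prediction"` key. `split_backbone_output` is a hypothetical helper showing how such outputs could be unified into the documented tuple form.

```python
from typing import Any


def split_backbone_output(output: Any) -> tuple[Any, dict[str, Any]]:
    """Unify a backbone's return value into (prediction, auxiliary_outputs).

    The accepted shapes are assumptions for illustration only.
    """
    if isinstance(output, tuple):
        # Tuple form: (prediction, {name: auxiliary_output}).
        if len(output) != 2 or not isinstance(output[1], dict):
            raise TypeError("expected a (prediction, dict) tuple")
        prediction, auxiliary = output
        return prediction, dict(auxiliary)
    if isinstance(output, dict):
        # Dict form: the prediction lives under a reserved key.
        if "prediction" not in output:
            raise KeyError("structured output must contain 'prediction'")
        auxiliary = {key: value for key, value in output.items() if key != "prediction"}
        return output["prediction"], auxiliary
    # A bare prediction carries no auxiliary outputs.
    return output, {}
```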
- compute_batch_outputs(batch_dictionary)
Compute normalized loss terms and denormalized TE metrics.
- Parameters:
batch_dictionary (dict[str, Tensor]) – Point-level batch emitted by the datamodule.
- Returns:
Batch outputs including normalized predictions, denormalized predictions, and metric tensors.
- Return type:
dict[str, torch.Tensor]
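The split described for `compute_batch_outputs` (loss terms in normalized space, metrics in physical TE units) can be sketched as follows. The batch keys `"input_tensor"` and `"target_tensor"`, the output keys, and the choice of MAE and RMSE as metrics are assumptions for illustration; lists stand in for tensors and a toy callable stands in for the backbone.

```python
import math
from typing import Callable

Row = list[float]


def compute_batch_outputs_sketch(
    batch: dict[str, list[Row]],
    backbone: Callable[[list[Row]], list[Row]],
    input_mean: Row,
    input_std: Row,
    target_mean: Row,
    target_std: Row,
) -> dict[str, object]:
    raw_inputs = batch["input_tensor"]  # hypothetical key
    raw_targets = batch["target_tensor"]  # hypothetical key
    norm_inputs = [[(x - m) / s for x, m, s in zip(r, input_mean, input_std)] for r in raw_inputs]
    norm_targets = [[(y - m) / s for y, m, s in zip(r, target_mean, target_std)] for r in raw_targets]
    norm_preds = backbone(norm_inputs)
    # Loss in normalized space, so every target contributes on a comparable scale.
    sq_errors = [(p - t) ** 2 for pr, tr in zip(norm_preds, norm_targets) for p, t in zip(pr, tr)]
    normalized_mse = sum(sq_errors) / len(sq_errors)
    # Metrics in physical TE units, so they are directly interpretable.
    preds = [[z * s + m for z, m, s in zip(r, target_mean, target_std)] for r in norm_preds]
    abs_errors = [abs(p - t) for pr, tr in zip(preds, raw_targets) for p, t in zip(pr, tr)]
    return {
        "normalized_predictions": norm_preds,
        "predictions": preds,
        "normalized_mse": normalized_mse,
        "mae": sum(abs_errors) / len(abs_errors),
        "rmse": math.sqrt(sum(e * e for e in abs_errors) / len(abs_errors)),
    }
```

For a single target, the physical-unit MSE is the target std squared times the normalized MSE, so a normalized loss and a physical-unit metric are not directly comparable without accounting for that scale.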
- compute_loss(batch_dictionary, log_prefix)
Compute one batch's loss terms and log them under a split-specific prefix.
- Parameters:
batch_dictionary (dict[str, Tensor]) – Point-level batch emitted by the datamodule.
log_prefix (str) – Split prefix for the logged Lightning metric names, for example train, val, or test.
- Returns:
Scalar normalized-space MSE loss.
- Return type:
torch.Tensor
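The page says `compute_loss` logs under a split prefix and returns the normalized-space MSE. A framework-free sketch follows; the `"{prefix}/{metric}"` naming scheme and the `normalized_mse`, `mae`, and `rmse` keys are hypothetical, and the `log` callable stands in for `self.log` inside a `LightningModule`.

```python
from typing import Callable


def compute_loss_sketch(
    batch_outputs: dict[str, float],
    log_prefix: str,
    log: Callable[[str, float], None],
) -> float:
    # Hypothetical key; the real module's output names may differ.
    loss = batch_outputs["normalized_mse"]
    log(f"{log_prefix}/loss", loss)
    # Physical-unit metrics are logged for monitoring only.
    for metric_name in ("mae", "rmse"):
        log(f"{log_prefix}/{metric_name}", batch_outputs[metric_name])
    # Only the normalized-space MSE is returned, so it is what the optimizer sees.
    return loss
```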
- training_step(batch_dictionary, batch_idx)
Run one Lightning training step and return the loss tensor.
- Parameters:
batch_dictionary (dict[str, Tensor])
batch_idx (int)
- Return type:
Tensor
- validation_step(batch_dictionary, batch_idx)
Run one Lightning validation step and return the loss tensor.
- Parameters:
batch_dictionary (dict[str, Tensor])
batch_idx (int)
- Return type:
Tensor
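The training and validation steps can differ only in the prefix they pass to `compute_loss`. The toy class below sketches that delegation pattern; `StepDispatchSketch`, its placeholder loss, and the `"squared_errors"` batch key are hypothetical.

```python
class StepDispatchSketch:
    """Toy stand-in showing how step methods can share one loss routine."""

    def __init__(self) -> None:
        # Collects logged values the way self.log would in a LightningModule.
        self.logged: dict[str, float] = {}

    def compute_loss(self, batch_dictionary: dict[str, list[float]], log_prefix: str) -> float:
        # Placeholder loss: mean of precomputed squared errors (hypothetical key).
        squared_errors = batch_dictionary["squared_errors"]
        loss = sum(squared_errors) / len(squared_errors)
        self.logged[f"{log_prefix}/loss"] = loss
        return loss

    def training_step(self, batch_dictionary: dict[str, list[float]], batch_idx: int) -> float:
        # Same routine, logged under the "train" prefix.
        return self.compute_loss(batch_dictionary, log_prefix="train")

    def validation_step(self, batch_dictionary: dict[str, list[float]], batch_idx: int) -> float:
        # Same routine, logged under the "val" prefix.
        return self.compute_loss(batch_dictionary, log_prefix="val")
```

In Lightning, the loss returned from `training_step` is the value that gets backpropagated; the value returned from `validation_step` is not used for optimization.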