# Source code for scripts.training.transmission_error_regression_module

"""Lightning regression module for normalized TE prediction workflows."""

from __future__ import annotations

# Import PyTorch Lightning Utilities
from lightning.pytorch import LightningModule

# Import PyTorch Utilities
import torch
import torch.nn as nn

# Import DataModule Utilities
from scripts.training.transmission_error_datamodule import NormalizationStatistics

class TransmissionErrorRegressionModule(LightningModule):
    """LightningModule that wraps TE backbones, normalization, and metrics.

    The module owns the normalization statistics as registered buffers (so
    they travel with checkpoints and device moves), computes the MSE loss in
    normalized space, and reports MAE/RMSE metrics in physical TE units.
    """
[docs] def __init__( self, regression_model: nn.Module, input_feature_dim: int, target_feature_dim: int, learning_rate: float = 1.0e-3, weight_decay: float = 1.0e-4, normalization_statistics: NormalizationStatistics | None = None, ) -> None: """Initialize the TE regression LightningModule. Args: regression_model: Backbone model operating on normalized inputs. input_feature_dim: Number of model input features. target_feature_dim: Number of regression targets. learning_rate: AdamW learning rate. weight_decay: AdamW weight decay. normalization_statistics: Optional normalization tensors loaded at construction time. """ super().__init__() # Validate Optimization Parameters assert input_feature_dim > 0, f"Input Feature Dim must be positive | {input_feature_dim}" assert target_feature_dim > 0, f"Target Feature Dim must be positive | {target_feature_dim}" assert learning_rate > 0.0, f"Learning Rate must be positive | {learning_rate}" assert weight_decay >= 0.0, f"Weight Decay must be non-negative | {weight_decay}" # Save Hyperparameters self.save_hyperparameters(ignore=["regression_model", "normalization_statistics"]) # Save Model And Loss self.regression_model = regression_model self.loss_function = nn.MSELoss() # Register Normalization Buffers self.register_buffer("input_feature_mean", torch.zeros(input_feature_dim, dtype=torch.float32)) self.register_buffer("input_feature_std", torch.ones(input_feature_dim, dtype=torch.float32)) self.register_buffer("target_mean", torch.zeros(target_feature_dim, dtype=torch.float32)) self.register_buffer("target_std", torch.ones(target_feature_dim, dtype=torch.float32)) # Initialize Normalization State self.normalization_statistics_initialized = False # Load Normalization Statistics If Available At Construction Time if normalization_statistics is not None: self.set_normalization_statistics(normalization_statistics)
[docs] def set_normalization_statistics(self, normalization_statistics: NormalizationStatistics) -> None: """Load normalization tensors into the module buffers. Args: normalization_statistics: Input and target statistics computed from the training split. """ # Validate Statistics Shapes assert normalization_statistics.input_feature_mean.shape == self.input_feature_mean.shape, ( f"Input Feature Mean shape mismatch | {tuple(normalization_statistics.input_feature_mean.shape)} " f"vs {tuple(self.input_feature_mean.shape)}" ) assert normalization_statistics.input_feature_std.shape == self.input_feature_std.shape, ( f"Input Feature Std shape mismatch | {tuple(normalization_statistics.input_feature_std.shape)} " f"vs {tuple(self.input_feature_std.shape)}" ) assert normalization_statistics.target_mean.shape == self.target_mean.shape, ( f"Target Mean shape mismatch | {tuple(normalization_statistics.target_mean.shape)} " f"vs {tuple(self.target_mean.shape)}" ) assert normalization_statistics.target_std.shape == self.target_std.shape, ( f"Target Std shape mismatch | {tuple(normalization_statistics.target_std.shape)} " f"vs {tuple(self.target_std.shape)}" ) # Copy Statistics Into Buffers self.input_feature_mean.copy_(normalization_statistics.input_feature_mean.float()) self.input_feature_std.copy_(torch.clamp(normalization_statistics.input_feature_std.float(), min=1.0e-8)) self.target_mean.copy_(normalization_statistics.target_mean.float()) self.target_std.copy_(torch.clamp(normalization_statistics.target_std.float(), min=1.0e-8)) # Mark Statistics As Ready self.normalization_statistics_initialized = True
[docs] def normalize_input_tensor(self, input_tensor: torch.Tensor) -> torch.Tensor: """Normalize model inputs with the registered training statistics.""" # Ensure Normalization Statistics Are Initialized Before Normalizing assert self.normalization_statistics_initialized, "Normalization Statistics must be initialized before training" return (input_tensor - self.input_feature_mean) / self.input_feature_std
[docs] def normalize_target_tensor(self, target_tensor: torch.Tensor) -> torch.Tensor: """Normalize regression targets with the registered statistics.""" # Ensure Normalization Statistics Are Initialized Before Normalizing assert self.normalization_statistics_initialized, "Normalization Statistics must be initialized before training" return (target_tensor - self.target_mean) / self.target_std
[docs] def denormalize_target_tensor(self, normalized_target_tensor: torch.Tensor) -> torch.Tensor: """Map normalized target predictions back to physical TE units.""" # Ensure Normalization Statistics Are Initialized Before Denormalizing assert self.normalization_statistics_initialized, "Normalization Statistics must be initialized before training" return (normalized_target_tensor * self.target_std) + self.target_mean
[docs] def forward(self, normalized_input_tensor: torch.Tensor) -> torch.Tensor: """Run the backbone on normalized inputs only.""" # Forward Pass Through Regression Model In Normalized Space return self.regression_model(normalized_input_tensor)
[docs] def forward_regression_model(self, input_tensor: torch.Tensor, normalized_input_tensor: torch.Tensor) -> tuple[torch.Tensor, dict[str, torch.Tensor]]: """Run the backbone while supporting structured auxiliary outputs. Args: input_tensor: Raw input tensor before normalization-aware routing. normalized_input_tensor: Normalized model input tensor. Returns: tuple[torch.Tensor, dict[str, torch.Tensor]]: Prediction tensor in normalized space plus any auxiliary structured outputs emitted by the backbone. """ auxiliary_output_dictionary: dict[str, torch.Tensor] = {} # Prefer Context-Aware Model Forward When The Backbone Needs Raw And Normalized Views Together if hasattr(self.regression_model, "compute_auxiliary_output_dictionary"): # Compute Auxiliary Output Dictionary computed_auxiliary_output_dictionary = self.regression_model.compute_auxiliary_output_dictionary(input_tensor, normalized_input_tensor) assert isinstance(computed_auxiliary_output_dictionary, dict), "Auxiliary Output Dictionary must be a dictionary" assert "prediction_tensor" in computed_auxiliary_output_dictionary, "Auxiliary Output Dictionary must contain prediction_tensor" # Extract Prediction And Auxiliary Output Tensors prediction_tensor = computed_auxiliary_output_dictionary["prediction_tensor"] auxiliary_output_dictionary = { key: value for key, value in computed_auxiliary_output_dictionary.items() if key != "prediction_tensor" and isinstance(value, torch.Tensor) } return prediction_tensor, auxiliary_output_dictionary # Fallback To A Simpler Context-Aware Forward Signature if hasattr(self.regression_model, "forward_with_input_context"): return self.regression_model.forward_with_input_context(input_tensor, normalized_input_tensor), auxiliary_output_dictionary return self(normalized_input_tensor), auxiliary_output_dictionary
[docs] def compute_batch_outputs(self, batch_dictionary: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: """Compute normalized loss terms and denormalized TE metrics. Args: batch_dictionary: Point-level batch emitted by the datamodule. Returns: dict[str, torch.Tensor]: Batch outputs including normalized predictions, denormalized predictions, and metric tensors. """ # Extract Batch Tensors input_tensor = batch_dictionary["input_tensor"].float() target_tensor = batch_dictionary["target_tensor"].float() # Normalize Input And Target normalized_input_tensor = self.normalize_input_tensor(input_tensor) normalized_target_tensor = self.normalize_target_tensor(target_tensor) # Forward Pass normalized_prediction_tensor, auxiliary_output_dictionary = self.forward_regression_model(input_tensor, normalized_input_tensor) # Compute Loss In Normalized Space loss = self.loss_function(normalized_prediction_tensor, normalized_target_tensor) # Denormalize Predictions For Interpretable Metrics prediction_tensor = self.denormalize_target_tensor(normalized_prediction_tensor) mae = torch.mean(torch.abs(prediction_tensor - target_tensor)) rmse = torch.sqrt(torch.mean(torch.square(prediction_tensor - target_tensor))) # Create Batch Output Dictionary batch_output_dictionary = { "input_tensor": input_tensor, "target_tensor": target_tensor, "normalized_input_tensor": normalized_input_tensor, "normalized_target_tensor": normalized_target_tensor, "normalized_prediction_tensor": normalized_prediction_tensor, "prediction_tensor": prediction_tensor, "loss": loss, "mae": mae, "rmse": rmse, } # Merge Optional Auxiliary Prediction Tensors Returned By Structured Models batch_output_dictionary.update(auxiliary_output_dictionary) return batch_output_dictionary
[docs] def compute_loss(self, batch_dictionary: dict[str, torch.Tensor], log_prefix: str) -> torch.Tensor: """Compute and log one split-specific loss bundle. Args: batch_dictionary: Point-level batch emitted by the datamodule. log_prefix: Prefix used for Lightning metric names such as `train`, `val`, or `test`. Returns: torch.Tensor: Scalar normalized-space MSE loss. """ # Compute Batch Outputs And Metrics batch_output_dictionary = self.compute_batch_outputs(batch_dictionary) input_tensor = batch_output_dictionary["input_tensor"] loss = batch_output_dictionary["loss"] mae = batch_output_dictionary["mae"] rmse = batch_output_dictionary["rmse"] # Log Metrics batch_size = int(input_tensor.shape[0]) self.log(f"{log_prefix}_loss", loss, on_step=False, on_epoch=True, prog_bar=True, batch_size=batch_size) self.log(f"{log_prefix}_mae", mae, on_step=False, on_epoch=True, prog_bar=(log_prefix != "train"), batch_size=batch_size) self.log(f"{log_prefix}_rmse", rmse, on_step=False, on_epoch=True, prog_bar=False, batch_size=batch_size) # Log Structured-Only Diagnostics When Available structured_prediction_tensor = batch_output_dictionary.get("structured_prediction_tensor") if isinstance(structured_prediction_tensor, torch.Tensor): structured_prediction_denormalized = self.denormalize_target_tensor(structured_prediction_tensor) structured_mae = torch.mean(torch.abs(structured_prediction_denormalized - batch_output_dictionary["target_tensor"])) structured_rmse = torch.sqrt(torch.mean(torch.square(structured_prediction_denormalized - batch_output_dictionary["target_tensor"]))) self.log(f"{log_prefix}_structured_mae", structured_mae, on_step=False, on_epoch=True, prog_bar=False, batch_size=batch_size) self.log(f"{log_prefix}_structured_rmse", structured_rmse, on_step=False, on_epoch=True, prog_bar=False, batch_size=batch_size) return loss
[docs] def training_step(self, batch_dictionary: dict[str, torch.Tensor], batch_idx: int) -> torch.Tensor: """Run one Lightning training step and return the loss tensor.""" # Compute Loss And Metrics For Training Step return self.compute_loss(batch_dictionary, "train")
[docs] def validation_step(self, batch_dictionary: dict[str, torch.Tensor], batch_idx: int) -> torch.Tensor: """Run one Lightning validation step and return the loss tensor.""" # Compute Loss And Metrics For Validation Step return self.compute_loss(batch_dictionary, "val")
[docs] def test_step(self, batch_dictionary: dict[str, torch.Tensor], batch_idx: int) -> torch.Tensor: """Run one Lightning test step and return the loss tensor.""" # Compute Loss And Metrics For Test Step return self.compute_loss(batch_dictionary, "test")
[docs] def configure_optimizers(self): """Configure the AdamW optimizer for TE regression training.""" # Configure AdamW Optimizer return torch.optim.AdamW(self.parameters(), lr=float(self.hparams.learning_rate), weight_decay=float(self.hparams.weight_decay))