"""Lightning regression module for normalized TE prediction workflows."""
from __future__ import annotations
# Import PyTorch Lightning Utilities
from lightning.pytorch import LightningModule
# Import PyTorch Utilities
import torch
import torch.nn as nn
# Import DataModule Utilities
from scripts.training.transmission_error_datamodule import NormalizationStatistics
class TransmissionErrorRegressionModule(LightningModule):
    """LightningModule that wraps TE backbones, normalization, and metrics.

    The backbone always runs in normalized space: inputs are standardized with
    training-split statistics before the forward pass, the MSE loss is computed
    on normalized targets, and predictions are denormalized back to physical TE
    units only for interpretable MAE/RMSE metrics.
    """

    def __init__(
        self,
        regression_model: nn.Module,
        input_feature_dim: int,
        target_feature_dim: int,
        learning_rate: float = 1.0e-3,
        weight_decay: float = 1.0e-4,
        normalization_statistics: NormalizationStatistics | None = None,
    ) -> None:
        """Initialize the TE regression LightningModule.

        Args:
            regression_model: Backbone model operating on normalized inputs.
            input_feature_dim: Number of model input features.
            target_feature_dim: Number of regression targets.
            learning_rate: AdamW learning rate.
            weight_decay: AdamW weight decay.
            normalization_statistics: Optional normalization tensors loaded at
                construction time.
        """
        super().__init__()
        # Validate Optimization Parameters
        assert input_feature_dim > 0, f"Input Feature Dim must be positive | {input_feature_dim}"
        assert target_feature_dim > 0, f"Target Feature Dim must be positive | {target_feature_dim}"
        assert learning_rate > 0.0, f"Learning Rate must be positive | {learning_rate}"
        assert weight_decay >= 0.0, f"Weight Decay must be non-negative | {weight_decay}"
        # Save Hyperparameters (modules and tensors are excluded from the checkpointed hparams)
        self.save_hyperparameters(ignore=["regression_model", "normalization_statistics"])
        # Save Model And Loss
        self.regression_model = regression_model
        self.loss_function = nn.MSELoss()
        # Register Normalization Buffers so they follow the module across devices
        # and are persisted in checkpoints alongside the model weights.
        self.register_buffer("input_feature_mean", torch.zeros(input_feature_dim, dtype=torch.float32))
        self.register_buffer("input_feature_std", torch.ones(input_feature_dim, dtype=torch.float32))
        self.register_buffer("target_mean", torch.zeros(target_feature_dim, dtype=torch.float32))
        self.register_buffer("target_std", torch.ones(target_feature_dim, dtype=torch.float32))
        # Initialize Normalization State
        # NOTE(review): this readiness flag is a plain attribute, so it is NOT
        # restored when loading a checkpoint — statistics must be re-applied via
        # set_normalization_statistics() after load. Confirm this is intended.
        self.normalization_statistics_initialized = False
        # Load Normalization Statistics If Available At Construction Time
        if normalization_statistics is not None:
            self.set_normalization_statistics(normalization_statistics)

    def set_normalization_statistics(self, normalization_statistics: NormalizationStatistics) -> None:
        """Load normalization tensors into the module buffers.

        Args:
            normalization_statistics: Input and target statistics computed from
                the training split.
        """
        # Validate Statistics Shapes against the buffers sized at construction
        assert normalization_statistics.input_feature_mean.shape == self.input_feature_mean.shape, (
            f"Input Feature Mean shape mismatch | {tuple(normalization_statistics.input_feature_mean.shape)} "
            f"vs {tuple(self.input_feature_mean.shape)}"
        )
        assert normalization_statistics.input_feature_std.shape == self.input_feature_std.shape, (
            f"Input Feature Std shape mismatch | {tuple(normalization_statistics.input_feature_std.shape)} "
            f"vs {tuple(self.input_feature_std.shape)}"
        )
        assert normalization_statistics.target_mean.shape == self.target_mean.shape, (
            f"Target Mean shape mismatch | {tuple(normalization_statistics.target_mean.shape)} "
            f"vs {tuple(self.target_mean.shape)}"
        )
        assert normalization_statistics.target_std.shape == self.target_std.shape, (
            f"Target Std shape mismatch | {tuple(normalization_statistics.target_std.shape)} "
            f"vs {tuple(self.target_std.shape)}"
        )
        # Copy Statistics Into Buffers; std values are clamped away from zero so
        # the later divisions in normalization can never produce inf/nan.
        self.input_feature_mean.copy_(normalization_statistics.input_feature_mean.float())
        self.input_feature_std.copy_(torch.clamp(normalization_statistics.input_feature_std.float(), min=1.0e-8))
        self.target_mean.copy_(normalization_statistics.target_mean.float())
        self.target_std.copy_(torch.clamp(normalization_statistics.target_std.float(), min=1.0e-8))
        # Mark Statistics As Ready
        self.normalization_statistics_initialized = True

    def normalize_input_tensor(self, input_tensor: torch.Tensor) -> torch.Tensor:
        """Normalize model inputs with the registered statistics.

        This method was referenced by compute_batch_outputs but was missing from
        the class; it mirrors normalize_target_tensor for the input buffers.
        """
        # Ensure Normalization Statistics Are Initialized Before Normalizing
        assert self.normalization_statistics_initialized, "Normalization Statistics must be initialized before training"
        return (input_tensor - self.input_feature_mean) / self.input_feature_std

    def normalize_target_tensor(self, target_tensor: torch.Tensor) -> torch.Tensor:
        """Normalize regression targets with the registered statistics."""
        # Ensure Normalization Statistics Are Initialized Before Normalizing
        assert self.normalization_statistics_initialized, "Normalization Statistics must be initialized before training"
        return (target_tensor - self.target_mean) / self.target_std

    def denormalize_target_tensor(self, normalized_target_tensor: torch.Tensor) -> torch.Tensor:
        """Map normalized target predictions back to physical TE units."""
        # Ensure Normalization Statistics Are Initialized Before Denormalizing
        assert self.normalization_statistics_initialized, "Normalization Statistics must be initialized before training"
        return (normalized_target_tensor * self.target_std) + self.target_mean

    def forward(self, normalized_input_tensor: torch.Tensor) -> torch.Tensor:
        """Run the backbone on normalized inputs only."""
        # Forward Pass Through Regression Model In Normalized Space
        return self.regression_model(normalized_input_tensor)

    def forward_regression_model(
        self, input_tensor: torch.Tensor, normalized_input_tensor: torch.Tensor
    ) -> tuple[torch.Tensor, dict[str, torch.Tensor]]:
        """Run the backbone while supporting structured auxiliary outputs.

        Args:
            input_tensor: Raw input tensor before normalization-aware routing.
            normalized_input_tensor: Normalized model input tensor.

        Returns:
            tuple[torch.Tensor, dict[str, torch.Tensor]]: Prediction tensor in
            normalized space plus any auxiliary structured outputs emitted by
            the backbone.
        """
        auxiliary_output_dictionary: dict[str, torch.Tensor] = {}
        # Prefer Context-Aware Model Forward When The Backbone Needs Raw And Normalized Views Together
        if hasattr(self.regression_model, "compute_auxiliary_output_dictionary"):
            # Compute Auxiliary Output Dictionary
            computed_auxiliary_output_dictionary = self.regression_model.compute_auxiliary_output_dictionary(
                input_tensor, normalized_input_tensor
            )
            assert isinstance(computed_auxiliary_output_dictionary, dict), "Auxiliary Output Dictionary must be a dictionary"
            assert "prediction_tensor" in computed_auxiliary_output_dictionary, "Auxiliary Output Dictionary must contain prediction_tensor"
            # Extract Prediction And Auxiliary Output Tensors; non-tensor values
            # are dropped so downstream logging only ever sees tensors.
            prediction_tensor = computed_auxiliary_output_dictionary["prediction_tensor"]
            auxiliary_output_dictionary = {
                key: value
                for key, value in computed_auxiliary_output_dictionary.items()
                if key != "prediction_tensor" and isinstance(value, torch.Tensor)
            }
            return prediction_tensor, auxiliary_output_dictionary
        # Fallback To A Simpler Context-Aware Forward Signature
        if hasattr(self.regression_model, "forward_with_input_context"):
            return self.regression_model.forward_with_input_context(input_tensor, normalized_input_tensor), auxiliary_output_dictionary
        return self(normalized_input_tensor), auxiliary_output_dictionary

    def compute_batch_outputs(self, batch_dictionary: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
        """Compute normalized loss terms and denormalized TE metrics.

        Args:
            batch_dictionary: Point-level batch emitted by the datamodule.

        Returns:
            dict[str, torch.Tensor]: Batch outputs including normalized
            predictions, denormalized predictions, and metric tensors.
        """
        # Extract Batch Tensors
        input_tensor = batch_dictionary["input_tensor"].float()
        target_tensor = batch_dictionary["target_tensor"].float()
        # Normalize Input And Target
        normalized_input_tensor = self.normalize_input_tensor(input_tensor)
        normalized_target_tensor = self.normalize_target_tensor(target_tensor)
        # Forward Pass
        normalized_prediction_tensor, auxiliary_output_dictionary = self.forward_regression_model(
            input_tensor, normalized_input_tensor
        )
        # Compute Loss In Normalized Space
        loss = self.loss_function(normalized_prediction_tensor, normalized_target_tensor)
        # Denormalize Predictions For Interpretable Metrics
        prediction_tensor = self.denormalize_target_tensor(normalized_prediction_tensor)
        mae = torch.mean(torch.abs(prediction_tensor - target_tensor))
        rmse = torch.sqrt(torch.mean(torch.square(prediction_tensor - target_tensor)))
        # Create Batch Output Dictionary
        batch_output_dictionary = {
            "input_tensor": input_tensor,
            "target_tensor": target_tensor,
            "normalized_input_tensor": normalized_input_tensor,
            "normalized_target_tensor": normalized_target_tensor,
            "normalized_prediction_tensor": normalized_prediction_tensor,
            "prediction_tensor": prediction_tensor,
            "loss": loss,
            "mae": mae,
            "rmse": rmse,
        }
        # Merge Optional Auxiliary Prediction Tensors Returned By Structured Models
        batch_output_dictionary.update(auxiliary_output_dictionary)
        return batch_output_dictionary

    def compute_loss(self, batch_dictionary: dict[str, torch.Tensor], log_prefix: str) -> torch.Tensor:
        """Compute and log one split-specific loss bundle.

        Args:
            batch_dictionary: Point-level batch emitted by the datamodule.
            log_prefix: Prefix used for Lightning metric names such as
                `train`, `val`, or `test`.

        Returns:
            torch.Tensor: Scalar normalized-space MSE loss.
        """
        # Compute Batch Outputs And Metrics
        batch_output_dictionary = self.compute_batch_outputs(batch_dictionary)
        input_tensor = batch_output_dictionary["input_tensor"]
        loss = batch_output_dictionary["loss"]
        mae = batch_output_dictionary["mae"]
        rmse = batch_output_dictionary["rmse"]
        # Log Metrics (epoch-level only; MAE appears on the progress bar for val/test)
        batch_size = int(input_tensor.shape[0])
        self.log(f"{log_prefix}_loss", loss, on_step=False, on_epoch=True, prog_bar=True, batch_size=batch_size)
        self.log(f"{log_prefix}_mae", mae, on_step=False, on_epoch=True, prog_bar=(log_prefix != "train"), batch_size=batch_size)
        self.log(f"{log_prefix}_rmse", rmse, on_step=False, on_epoch=True, prog_bar=False, batch_size=batch_size)
        # Log Structured-Only Diagnostics When Available
        structured_prediction_tensor = batch_output_dictionary.get("structured_prediction_tensor")
        if isinstance(structured_prediction_tensor, torch.Tensor):
            structured_prediction_denormalized = self.denormalize_target_tensor(structured_prediction_tensor)
            structured_mae = torch.mean(torch.abs(structured_prediction_denormalized - batch_output_dictionary["target_tensor"]))
            structured_rmse = torch.sqrt(torch.mean(torch.square(structured_prediction_denormalized - batch_output_dictionary["target_tensor"])))
            self.log(f"{log_prefix}_structured_mae", structured_mae, on_step=False, on_epoch=True, prog_bar=False, batch_size=batch_size)
            self.log(f"{log_prefix}_structured_rmse", structured_rmse, on_step=False, on_epoch=True, prog_bar=False, batch_size=batch_size)
        return loss

    def training_step(self, batch_dictionary: dict[str, torch.Tensor], batch_idx: int) -> torch.Tensor:
        """Run one Lightning training step and return the loss tensor."""
        # Compute Loss And Metrics For Training Step
        return self.compute_loss(batch_dictionary, "train")

    def validation_step(self, batch_dictionary: dict[str, torch.Tensor], batch_idx: int) -> torch.Tensor:
        """Run one Lightning validation step and return the loss tensor."""
        # Compute Loss And Metrics For Validation Step
        return self.compute_loss(batch_dictionary, "val")

    def test_step(self, batch_dictionary: dict[str, torch.Tensor], batch_idx: int) -> torch.Tensor:
        """Run one Lightning test step and return the loss tensor."""
        # Compute Loss And Metrics For Test Step
        return self.compute_loss(batch_dictionary, "test")

    def configure_optimizers(self) -> torch.optim.Optimizer:
        """Create the AdamW optimizer promised by the saved hyperparameters.

        This hook was missing even though learning_rate and weight_decay are
        validated, documented as AdamW settings, and stored via
        save_hyperparameters; without it Lightning cannot train the module.
        """
        # Build AdamW From The Checkpointed Hyperparameters
        return torch.optim.AdamW(
            self.parameters(),
            lr=self.hparams.learning_rate,
            weight_decay=self.hparams.weight_decay,
        )