"""Lightning regression module for normalized TE prediction workflows."""
from __future__ import annotations
# Import PyTorch Lightning Utilities
from lightning.pytorch import LightningModule
# Import PyTorch Utilities
import torch
import torch.nn as nn
import torch.nn.functional as torch_functional
# Import DataModule Utilities
from scripts.training.transmission_error_datamodule import NormalizationStatistics
[docs]
class TransmissionErrorRegressionModule(LightningModule):
"""LightningModule that wraps TE backbones, normalization, and metrics."""
[docs]
def __init__(
    self,
    regression_model: nn.Module,
    input_feature_dim: int,
    target_feature_dim: int,
    learning_rate: float = 1.0e-3,
    weight_decay: float = 1.0e-4,
    normalization_statistics: NormalizationStatistics | None = None,
    loss_configuration: dict | None = None,
) -> None:
    """Build the module around a backbone with identity normalization buffers.

    Args:
        regression_model: Backbone stored as ``self.regression_model``.
        input_feature_dim: Length of the input mean/std buffers; must be positive.
        target_feature_dim: Length of the target mean/std buffers; must be positive.
        learning_rate: Positive learning rate; only validated and recorded as a
            hyperparameter here (the optimizer is not built in this method).
        weight_decay: Non-negative weight decay; validated and recorded likewise.
        normalization_statistics: Optional statistics loaded immediately through
            ``set_normalization_statistics``.
        loss_configuration: Optional loss settings; ``None`` or an empty dict
            falls back to the defaults of ``_normalize_loss_configuration``.

    Raises:
        AssertionError: If a dimension, the learning rate, or the weight decay
            is out of range.
    """
    super().__init__()

    # Reject Invalid Dimensions And Optimizer Settings
    assert input_feature_dim > 0, f"Input Feature Dim must be positive | {input_feature_dim}"
    assert target_feature_dim > 0, f"Target Feature Dim must be positive | {target_feature_dim}"
    assert learning_rate > 0.0, f"Learning Rate must be positive | {learning_rate}"
    assert weight_decay >= 0.0, f"Weight Decay must be non-negative | {weight_decay}"

    # Record Hyperparameters Except The Backbone And The Statistics Object
    self.save_hyperparameters(ignore=["regression_model", "normalization_statistics"])

    # Attach Backbone And Normalized Loss Settings
    self.regression_model = regression_model
    raw_loss_configuration = loss_configuration if loss_configuration else {}
    self.loss_configuration = self._normalize_loss_configuration(raw_loss_configuration)

    # Register Identity Normalization Buffers (Zero Mean, Unit Std)
    buffer_layout = (
        ("input_feature_mean", torch.zeros, input_feature_dim),
        ("input_feature_std", torch.ones, input_feature_dim),
        ("target_mean", torch.zeros, target_feature_dim),
        ("target_std", torch.ones, target_feature_dim),
    )
    for buffer_name, make_tensor, feature_dim in buffer_layout:
        self.register_buffer(buffer_name, make_tensor(feature_dim, dtype=torch.float32))

    # Statistics Count As Missing Until Explicitly Loaded
    self.normalization_statistics_initialized = False
    if normalization_statistics is not None:
        self.set_normalization_statistics(normalization_statistics)
def _normalize_loss_configuration(self, loss_configuration: dict) -> dict[str, object]:
"""Normalize optional curve-aware loss settings."""
# Resolve Loss Profile And Weights
loss_profile = str(loss_configuration.get("profile", "pointwise_control")).strip().lower()
pointwise_loss_name = str(loss_configuration.get("pointwise_loss", "mse")).strip().lower()
weight_dictionary = dict(loss_configuration.get("weights", {}))
harmonic_index_list = list(loss_configuration.get("harmonic_index_list", []))
huber_delta = float(loss_configuration.get("huber_delta", loss_configuration.get("smooth_l1_beta", 1.0)))
quantile_level_list = [
float(quantile_level)
for quantile_level in loss_configuration.get("quantile_level_list", [0.1, 0.5, 0.9])
]
deterministic_output_index = int(loss_configuration.get("deterministic_output_index", 0))
gaussian_log_sigma_min = float(loss_configuration.get("gaussian_log_sigma_min", -7.0))
gaussian_log_sigma_max = float(loss_configuration.get("gaussian_log_sigma_max", 5.0))
gaussian_sigma_min = float(loss_configuration.get("gaussian_sigma_min", 1.0e-4))
mixture_component_count = int(loss_configuration.get("mixture_component_count", 1))
mixture_log_sigma_min = float(loss_configuration.get("mixture_log_sigma_min", -7.0))
mixture_log_sigma_max = float(loss_configuration.get("mixture_log_sigma_max", 5.0))
mixture_sigma_min = float(loss_configuration.get("mixture_sigma_min", 1.0e-4))
quantile_crossing_penalty_weight = float(loss_configuration.get("quantile_crossing_penalty_weight", 0.0))
assert deterministic_output_index >= 0, f"Deterministic Output Index must be non-negative | {deterministic_output_index}"
assert gaussian_log_sigma_min < gaussian_log_sigma_max, (
f"Gaussian log-sigma bounds must be ordered | {gaussian_log_sigma_min} vs {gaussian_log_sigma_max}"
)
assert gaussian_sigma_min > 0.0, f"Gaussian Sigma Min must be positive | {gaussian_sigma_min}"
assert mixture_component_count > 0, f"Mixture Component Count must be positive | {mixture_component_count}"
assert mixture_log_sigma_min < mixture_log_sigma_max, (
f"Mixture log-sigma bounds must be ordered | {mixture_log_sigma_min} vs {mixture_log_sigma_max}"
)
assert mixture_sigma_min > 0.0, f"Mixture Sigma Min must be positive | {mixture_sigma_min}"
assert all(0.0 < quantile_level < 1.0 for quantile_level in quantile_level_list), (
f"Quantile levels must be inside (0, 1) | {quantile_level_list}"
)
return {
"profile": loss_profile,
"pointwise_loss": pointwise_loss_name,
"huber_delta": huber_delta,
"quantile_level_list": quantile_level_list,
"deterministic_output_index": deterministic_output_index,
"gaussian_log_sigma_min": gaussian_log_sigma_min,
"gaussian_log_sigma_max": gaussian_log_sigma_max,
"gaussian_sigma_min": gaussian_sigma_min,
"mixture_component_count": mixture_component_count,
"mixture_log_sigma_min": mixture_log_sigma_min,
"mixture_log_sigma_max": mixture_log_sigma_max,
"mixture_sigma_min": mixture_sigma_min,
"quantile_crossing_penalty_weight": quantile_crossing_penalty_weight,
"point_weight": float(weight_dictionary.get("point", 1.0)),
"centered_weight": float(weight_dictionary.get("centered", 0.0)),
"offset_weight": float(weight_dictionary.get("offset", 0.0)),
"amplitude_weight": float(weight_dictionary.get("amplitude", 0.0)),
"harmonic_weight": float(weight_dictionary.get("harmonic", 0.0)),
"harmonic_index_list": [int(harmonic_index) for harmonic_index in harmonic_index_list if int(harmonic_index) > 0],
}
[docs]
def compute_pointwise_prediction_loss(self, prediction_tensor: torch.Tensor, target_tensor: torch.Tensor) -> torch.Tensor:
    """Evaluate the pointwise loss selected by ``loss_configuration["pointwise_loss"]``.

    Args:
        prediction_tensor: Model output in normalized space.
        target_tensor: Normalized regression targets.

    Returns:
        torch.Tensor: Mean-reduced loss value.

    Raises:
        ValueError: If the configured loss name is not recognized.
        AssertionError: If a Huber-style loss is configured with a
            non-positive delta.
    """
    loss_name = str(self.loss_configuration["pointwise_loss"])
    residual_tensor = prediction_tensor - target_tensor

    # Deterministic Losses
    if loss_name in ("mse", "l2", "mean_squared_error"):
        return residual_tensor.square().mean()
    if loss_name in ("mae", "l1", "mean_absolute_error"):
        return residual_tensor.abs().mean()
    if loss_name in ("smooth_l1", "huber"):
        huber_delta = float(self.loss_configuration["huber_delta"])
        assert huber_delta > 0.0, f"Huber Delta must be positive | {huber_delta}"
        return torch_functional.huber_loss(prediction_tensor, target_tensor, reduction="mean", delta=huber_delta)
    if loss_name in ("log_cosh", "logcosh"):
        # log(cosh(e)) Written As |e| + softplus(-2|e|) - log(2)
        residual_magnitude_tensor = residual_tensor.abs()
        log_two_tensor = torch.log(torch.tensor(2.0, device=residual_magnitude_tensor.device))
        log_cosh_tensor = (
            residual_magnitude_tensor
            + torch_functional.softplus(-2.0 * residual_magnitude_tensor)
            - log_two_tensor
        )
        return log_cosh_tensor.mean()

    # Probabilistic Losses Delegate To Dedicated Methods
    probabilistic_dispatch = (
        (("quantile_pinball", "pinball", "quantile"), "compute_quantile_pinball_loss"),
        (("gaussian_nll", "gaussian_negative_log_likelihood"), "compute_gaussian_negative_log_likelihood_loss"),
        (("mixture_density_nll", "mdn_nll", "mixture_gaussian_nll"), "compute_mixture_density_negative_log_likelihood_loss"),
    )
    for alias_group, handler_name in probabilistic_dispatch:
        if loss_name in alias_group:
            return getattr(self, handler_name)(prediction_tensor, target_tensor)

    raise ValueError(f"Unsupported pointwise loss | {loss_name}")
[docs]
def compute_quantile_pinball_loss(self, prediction_tensor: torch.Tensor, target_tensor: torch.Tensor) -> torch.Tensor:
    """Average the pinball loss over all configured quantile channels.

    The last dimension of ``prediction_tensor`` must hold one channel per entry
    of ``loss_configuration["quantile_level_list"]``. When the crossing penalty
    weight is positive and more than one channel exists, the mean positive gap
    between each channel and its right-hand neighbour is added, scaled by that
    weight.

    Raises:
        AssertionError: If the channel count differs from the number of
            configured quantile levels.
    """
    configured_level_list = list(self.loss_configuration["quantile_level_list"])
    channel_count = prediction_tensor.shape[-1]
    assert channel_count == len(configured_level_list), (
        f"Quantile output size mismatch | {channel_count} vs {len(configured_level_list)}"
    )

    # Broadcast The Levels Along The Leading Dimension
    level_row_tensor = torch.as_tensor(
        configured_level_list,
        device=prediction_tensor.device,
        dtype=prediction_tensor.dtype,
    ).reshape(1, -1)

    # Pinball Loss: max(q * e, (q - 1) * e) With e = target - prediction
    residual_tensor = target_tensor - prediction_tensor
    under_prediction_cost_tensor = level_row_tensor * residual_tensor
    over_prediction_cost_tensor = (level_row_tensor - 1.0) * residual_tensor
    total_loss = torch.maximum(under_prediction_cost_tensor, over_prediction_cost_tensor).mean()

    penalty_weight = float(self.loss_configuration["quantile_crossing_penalty_weight"])
    if not (penalty_weight > 0.0 and channel_count > 1):
        return total_loss

    # Penalize A Channel Exceeding Its Right-Hand Neighbour
    descending_gap_tensor = prediction_tensor[:, :-1] - prediction_tensor[:, 1:]
    crossing_penalty = torch_functional.relu(descending_gap_tensor).mean()
    return total_loss + penalty_weight * crossing_penalty
[docs]
def compute_gaussian_negative_log_likelihood_loss(
    self,
    prediction_tensor: torch.Tensor,
    target_tensor: torch.Tensor,
) -> torch.Tensor:
    """Compute guarded Gaussian NLL in normalized target space.

    Column 0 of ``prediction_tensor`` is read as the mean and column 1 as the
    raw log-sigma. The log-sigma is clamped to the configured
    ``gaussian_log_sigma_min`` / ``gaussian_log_sigma_max`` range and the
    resulting sigma is floored at ``gaussian_sigma_min``.

    The log-sigma term is floored at ``log(gaussian_sigma_min)`` so that it
    always describes the same sigma used in the residual term. Without this,
    once the sigma floor is active the residual term stops reacting to
    log-sigma while the log term keeps decreasing, so the loss rewards
    shrinking log-sigma without improving the fit. When the floor lies below
    ``exp(gaussian_log_sigma_min)`` (the default settings) the result is
    unchanged.

    Args:
        prediction_tensor: Tensor whose last dimension has size 2.
        target_tensor: Normalized targets broadcastable against one column.

    Returns:
        torch.Tensor: Mean negative log-likelihood.

    Raises:
        AssertionError: If the last dimension of ``prediction_tensor`` is not 2.
    """
    assert prediction_tensor.shape[-1] == 2, f"Gaussian output must contain mu and log_sigma | {prediction_tensor.shape[-1]}"

    # Split Mean And Raw Log-Sigma Channels
    mu_tensor = prediction_tensor[:, 0:1]
    raw_log_sigma_tensor = prediction_tensor[:, 1:2]

    # Guard The Scale: Clamp Log-Sigma, Then Floor Sigma
    log_sigma_tensor = torch.clamp(
        raw_log_sigma_tensor,
        min=float(self.loss_configuration["gaussian_log_sigma_min"]),
        max=float(self.loss_configuration["gaussian_log_sigma_max"]),
    )
    sigma_floor = float(self.loss_configuration["gaussian_sigma_min"])
    sigma_tensor = torch.clamp(torch.exp(log_sigma_tensor), min=sigma_floor)

    # Keep The Log Term Consistent With The Floored Sigma
    effective_log_sigma_tensor = log_sigma_tensor
    if sigma_floor > 0.0:
        log_sigma_floor_tensor = torch.log(
            torch.as_tensor(sigma_floor, device=prediction_tensor.device, dtype=prediction_tensor.dtype)
        )
        effective_log_sigma_tensor = torch.maximum(log_sigma_tensor, log_sigma_floor_tensor)

    normalized_residual_tensor = (target_tensor - mu_tensor) / sigma_tensor
    log_two_pi_tensor = torch.log(torch.as_tensor(2.0 * torch.pi, device=prediction_tensor.device, dtype=prediction_tensor.dtype))
    return torch.mean(0.5 * torch.square(normalized_residual_tensor) + effective_log_sigma_tensor + 0.5 * log_two_pi_tensor)
[docs]
def split_mixture_density_output_tensor(
    self,
    prediction_tensor: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Slice a mixture-density output into logits, means, and clamped log-sigmas.

    With ``K = loss_configuration["mixture_component_count"]``, columns
    ``[0, K)`` are returned as mixture logits, ``[K, 2K)`` as component means,
    and ``[2K, 3K)`` as log-sigmas clamped to the configured
    ``mixture_log_sigma_min`` / ``mixture_log_sigma_max`` range.

    Raises:
        AssertionError: If the last dimension is not ``3 * K``.
    """
    component_count = int(self.loss_configuration["mixture_component_count"])
    expected_output_size = component_count * 3
    assert prediction_tensor.shape[-1] == expected_output_size, (
        f"MDN output size mismatch | {prediction_tensor.shape[-1]} vs {expected_output_size}"
    )

    # Column Boundaries Of The Three Equal-Width Groups
    mean_start = component_count
    log_sigma_start = 2 * component_count
    log_sigma_stop = 3 * component_count

    mixture_logit_tensor = prediction_tensor[:, :mean_start]
    component_mean_tensor = prediction_tensor[:, mean_start:log_sigma_start]
    unclamped_log_sigma_tensor = prediction_tensor[:, log_sigma_start:log_sigma_stop]

    # Guard The Log-Sigma Range
    lower_bound = float(self.loss_configuration["mixture_log_sigma_min"])
    upper_bound = float(self.loss_configuration["mixture_log_sigma_max"])
    log_sigma_tensor = torch.clamp(unclamped_log_sigma_tensor, min=lower_bound, max=upper_bound)
    return mixture_logit_tensor, component_mean_tensor, log_sigma_tensor
[docs]
def compute_mixture_density_negative_log_likelihood_loss(
    self,
    prediction_tensor: torch.Tensor,
    target_tensor: torch.Tensor,
) -> torch.Tensor:
    """Compute stable Gaussian-mixture NLL in normalized target space.

    The output is split with ``split_mixture_density_output_tensor``; each
    component sigma is floored at ``mixture_sigma_min`` and the mixture
    likelihood is accumulated with ``log_softmax`` and ``logsumexp``.

    The per-component log-sigma term is floored at ``log(mixture_sigma_min)``
    so it matches the sigma used in the residual term. Without this, once the
    sigma floor is active the component log-probability keeps growing as
    log-sigma shrinks, so the loss is no longer a valid NLL. When the floor
    lies below ``exp(mixture_log_sigma_min)`` (the default settings) the
    result is unchanged.

    Args:
        prediction_tensor: Raw mixture-density output.
        target_tensor: Normalized targets broadcastable against the
            component means.

    Returns:
        torch.Tensor: Mean negative log-likelihood.
    """
    mixture_logit_tensor, component_mean_tensor, log_sigma_tensor = self.split_mixture_density_output_tensor(
        prediction_tensor
    )

    # Floor The Component Scales
    sigma_floor = float(self.loss_configuration["mixture_sigma_min"])
    sigma_tensor = torch.clamp(torch.exp(log_sigma_tensor), min=sigma_floor)

    # Keep The Log Term Consistent With The Floored Sigma
    effective_log_sigma_tensor = log_sigma_tensor
    if sigma_floor > 0.0:
        log_sigma_floor_tensor = torch.log(
            torch.as_tensor(sigma_floor, device=prediction_tensor.device, dtype=prediction_tensor.dtype)
        )
        effective_log_sigma_tensor = torch.maximum(log_sigma_tensor, log_sigma_floor_tensor)

    # Per-Component Gaussian Log-Density
    normalized_residual_tensor = (target_tensor - component_mean_tensor) / sigma_tensor
    log_two_pi_tensor = torch.log(
        torch.as_tensor(2.0 * torch.pi, device=prediction_tensor.device, dtype=prediction_tensor.dtype)
    )
    component_log_probability_tensor = (
        -0.5 * torch.square(normalized_residual_tensor)
        - effective_log_sigma_tensor
        - 0.5 * log_two_pi_tensor
    )

    # Combine Components In Log Space
    mixture_log_weight_tensor = torch_functional.log_softmax(mixture_logit_tensor, dim=-1)
    sample_log_probability_tensor = torch.logsumexp(
        mixture_log_weight_tensor + component_log_probability_tensor,
        dim=-1,
    )
    return torch.mean(-sample_log_probability_tensor)
[docs]
def compute_mixture_expectation_tensor(self, model_output_tensor: torch.Tensor) -> torch.Tensor:
    """Collapse a mixture-density output to its weighted mean.

    The mixture logits are turned into weights with a softmax, and the
    component means are summed under those weights along the last dimension,
    which is kept with size 1. The log-sigma part of the split is not used.
    """
    split_result = self.split_mixture_density_output_tensor(model_output_tensor)
    logit_tensor, mean_tensor = split_result[0], split_result[1]
    weight_tensor = torch.softmax(logit_tensor, dim=-1)
    weighted_mean_tensor = weight_tensor * mean_tensor
    return weighted_mean_tensor.sum(dim=-1, keepdim=True)
[docs]
def compute_probabilistic_metric_dictionary(
    self,
    model_output_tensor: torch.Tensor,
    normalized_target_tensor: torch.Tensor,
) -> dict[str, torch.Tensor]:
    """Compute optional uncertainty diagnostics from raw model outputs.

    Which diagnostics are produced depends on
    ``loss_configuration["pointwise_loss"]``; deterministic losses yield an
    empty dictionary.

    * Quantile losses: ``interval_coverage``, ``interval_width`` and
      ``quantile_crossing_rate`` using the first and last channels as the
      interval bounds.
    * Gaussian NLL: ``interval_coverage``, ``interval_width`` and
      ``mean_sigma`` for a mean +/- z * sigma interval.
    * Mixture-density NLL: weight entropy, effective component count, mean
      component sigma and the spread between component means.

    ``quantile_crossing_rate`` is the fraction of samples in which any channel
    exceeds its right-hand neighbour. Comparing only the first and last
    channels would miss crossings that involve interior quantiles, which the
    crossing penalty of the pinball loss does penalize.

    Args:
        model_output_tensor: Raw normalized-space model output.
        normalized_target_tensor: Normalized targets.

    Returns:
        dict[str, torch.Tensor]: Scalar diagnostics keyed by metric name.

    Raises:
        AssertionError: If a quantile loss is configured and fewer than three
            output channels are present.
    """
    pointwise_loss_name = str(self.loss_configuration["pointwise_loss"])
    metric_dictionary: dict[str, torch.Tensor] = {}

    # Quantile Interval Diagnostics
    if pointwise_loss_name in ["quantile_pinball", "pinball", "quantile"]:
        assert model_output_tensor.shape[-1] >= 3, "Quantile diagnostics require lower, median, and upper channels"
        lower_quantile_tensor = model_output_tensor[:, 0:1]
        upper_quantile_tensor = model_output_tensor[:, -1:]
        # Width Is Reported After Denormalization; Coverage Is Checked In Normalized Space
        interval_width_tensor = self.denormalize_target_tensor(upper_quantile_tensor) - self.denormalize_target_tensor(lower_quantile_tensor)
        coverage_tensor = (
            (normalized_target_tensor >= lower_quantile_tensor)
            & (normalized_target_tensor <= upper_quantile_tensor)
        ).float()
        # A Sample Is Crossed When Any Adjacent Channel Pair Decreases
        # NOTE(review): assumes channels are ordered by ascending quantile level - confirm against the configured level list
        adjacent_crossing_tensor = model_output_tensor[:, :-1] > model_output_tensor[:, 1:]
        crossing_tensor = torch.any(adjacent_crossing_tensor, dim=-1).float()
        metric_dictionary["interval_coverage"] = torch.mean(coverage_tensor)
        metric_dictionary["interval_width"] = torch.mean(torch.abs(interval_width_tensor))
        metric_dictionary["quantile_crossing_rate"] = torch.mean(crossing_tensor)

    # Gaussian Interval Diagnostics
    if pointwise_loss_name in ["gaussian_nll", "gaussian_negative_log_likelihood"]:
        mu_tensor = model_output_tensor[:, 0:1]
        log_sigma_tensor = torch.clamp(
            model_output_tensor[:, 1:2],
            min=float(self.loss_configuration["gaussian_log_sigma_min"]),
            max=float(self.loss_configuration["gaussian_log_sigma_max"]),
        )
        sigma_tensor = torch.clamp(torch.exp(log_sigma_tensor), min=float(self.loss_configuration["gaussian_sigma_min"]))
        # Standard-Normal 0.9 Quantile, I.E. A Central 80% Interval
        z80_value = 1.2815515655446004
        lower_tensor = mu_tensor - z80_value * sigma_tensor
        upper_tensor = mu_tensor + z80_value * sigma_tensor
        interval_width_tensor = self.denormalize_target_tensor(upper_tensor) - self.denormalize_target_tensor(lower_tensor)
        coverage_tensor = ((normalized_target_tensor >= lower_tensor) & (normalized_target_tensor <= upper_tensor)).float()
        metric_dictionary["interval_coverage"] = torch.mean(coverage_tensor)
        metric_dictionary["interval_width"] = torch.mean(torch.abs(interval_width_tensor))
        # Sigma Is Rescaled By The Target Std Only (No Mean Shift)
        metric_dictionary["mean_sigma"] = torch.mean(sigma_tensor * self.target_std)

    # Mixture-Density Diagnostics
    if pointwise_loss_name in ["mixture_density_nll", "mdn_nll", "mixture_gaussian_nll"]:
        mixture_logit_tensor, component_mean_tensor, log_sigma_tensor = self.split_mixture_density_output_tensor(
            model_output_tensor
        )
        mixture_weight_tensor = torch.softmax(mixture_logit_tensor, dim=-1)
        log_mixture_weight_tensor = torch_functional.log_softmax(mixture_logit_tensor, dim=-1)
        sigma_tensor = torch.clamp(
            torch.exp(log_sigma_tensor),
            min=float(self.loss_configuration["mixture_sigma_min"]),
        )
        # Entropy Of The Mixture Weights And Its Exponential
        entropy_tensor = -torch.sum(mixture_weight_tensor * log_mixture_weight_tensor, dim=-1)
        effective_component_tensor = torch.exp(entropy_tensor)
        component_mean_physical_tensor = self.denormalize_target_tensor(component_mean_tensor)
        component_scale_physical_tensor = sigma_tensor * self.target_std
        # Spread Between The Largest And Smallest Component Mean Per Sample
        component_separation_tensor = (
            torch.max(component_mean_physical_tensor, dim=-1).values
            - torch.min(component_mean_physical_tensor, dim=-1).values
        )
        metric_dictionary["mixture_weight_entropy"] = torch.mean(entropy_tensor)
        metric_dictionary["mixture_effective_components"] = torch.mean(effective_component_tensor)
        metric_dictionary["mixture_mean_sigma"] = torch.mean(component_scale_physical_tensor)
        metric_dictionary["mixture_component_separation"] = torch.mean(torch.abs(component_separation_tensor))

    return metric_dictionary
[docs]
def compute_curve_aware_loss_dictionary(
    self,
    batch_dictionary: dict[str, torch.Tensor],
    batch_output_dictionary: dict[str, torch.Tensor],
    pointwise_loss: torch.Tensor,
) -> dict[str, torch.Tensor]:
    """Combine the pointwise loss with optional per-curve shape terms.

    Curve terms are computed only when the batch carries a non-empty
    ``sequence_count_per_curve`` (or, failing that, ``point_count_per_curve``)
    tensor. Its entries are read as lengths of consecutive row segments of the
    prediction and target tensors. For each non-empty segment the method
    compares mean-removed values (centered shape), segment means (offset),
    max-minus-min ranges (amplitude) and, when harmonic indices are configured,
    sparse sine/cosine projections. Each term is averaged over segments and
    stays at zero when it cannot be computed.

    ``angular_position_deg`` is read from the batch only when harmonic indices
    are configured, so batches without angles still get the other curve terms.

    Args:
        batch_dictionary: Collated batch; may hold the count tensors,
            ``curve_count`` and ``angular_position_deg``.
        batch_output_dictionary: Must hold ``prediction_tensor`` and
            ``target_tensor`` whenever curve terms are computed.
        pointwise_loss: Already computed pointwise loss.

    Returns:
        dict[str, torch.Tensor]: Weighted total under ``loss``, every
        individual term, and ``curve_aware_curve_count``.
    """
    # Extract Configured Weights
    point_weight = float(self.loss_configuration["point_weight"])
    centered_weight = float(self.loss_configuration["centered_weight"])
    offset_weight = float(self.loss_configuration["offset_weight"])
    amplitude_weight = float(self.loss_configuration["amplitude_weight"])
    harmonic_weight = float(self.loss_configuration["harmonic_weight"])
    harmonic_index_list = list(self.loss_configuration["harmonic_index_list"])

    # Initialize Loss Terms (Zero Derived From The Pointwise Loss Keeps Device And Dtype)
    zero_loss = pointwise_loss * 0.0
    centered_shape_loss = zero_loss
    curve_offset_loss = zero_loss
    curve_amplitude_loss = zero_loss
    sparse_harmonic_shape_loss = zero_loss

    declared_curve_count = batch_dictionary.get("curve_count", None)
    count_per_curve_tensor = batch_dictionary.get(
        "sequence_count_per_curve",
        batch_dictionary.get("point_count_per_curve", None),
    )
    can_compute_curve_terms = isinstance(count_per_curve_tensor, torch.Tensor) and int(count_per_curve_tensor.numel()) > 0

    # Compute Curve Terms Only When The Collate Function Preserved Per-Curve Counts
    if can_compute_curve_terms:
        prediction_tensor = batch_output_dictionary["prediction_tensor"]
        target_tensor = batch_output_dictionary["target_tensor"]

        # Angles Are Only Needed By The Harmonic Term
        needs_angular_positions = bool(harmonic_index_list)
        angular_position_tensor = None
        if needs_angular_positions:
            angular_position_tensor = batch_dictionary["angular_position_deg"].to(prediction_tensor.device).float()

        # Segment Lengths Are Used As Python Ints, So No Device Transfer Is Needed
        segment_length_list = count_per_curve_tensor.long().tolist()

        centered_loss_list: list[torch.Tensor] = []
        offset_loss_list: list[torch.Tensor] = []
        amplitude_loss_list: list[torch.Tensor] = []
        harmonic_loss_list: list[torch.Tensor] = []
        start_index = 0

        # Walk Contiguous Per-Curve Segments In The Collated Batch
        for segment_length in segment_length_list:
            end_index = start_index + int(segment_length)
            segment_slice = slice(start_index, end_index)
            curve_prediction_tensor = prediction_tensor[segment_slice]
            curve_target_tensor = target_tensor[segment_slice]
            start_index = end_index
            if int(curve_prediction_tensor.numel()) == 0:
                continue

            # Centered Shape And Offset Terms
            prediction_mean_tensor = torch.mean(curve_prediction_tensor, dim=0, keepdim=True)
            target_mean_tensor = torch.mean(curve_target_tensor, dim=0, keepdim=True)
            centered_prediction_tensor = curve_prediction_tensor - prediction_mean_tensor
            centered_target_tensor = curve_target_tensor - target_mean_tensor
            centered_loss_list.append(torch.mean(torch.square(centered_prediction_tensor - centered_target_tensor)))
            offset_loss_list.append(torch.mean(torch.square(prediction_mean_tensor - target_mean_tensor)))

            # Amplitude Term (Max Minus Min Along The Segment)
            prediction_amplitude_tensor = torch.max(curve_prediction_tensor, dim=0).values - torch.min(curve_prediction_tensor, dim=0).values
            target_amplitude_tensor = torch.max(curve_target_tensor, dim=0).values - torch.min(curve_target_tensor, dim=0).values
            amplitude_loss_list.append(torch.mean(torch.square(prediction_amplitude_tensor - target_amplitude_tensor)))

            # Harmonic Term
            if needs_angular_positions:
                harmonic_loss_list.append(
                    self.compute_sparse_harmonic_shape_loss(
                        angular_position_tensor[segment_slice],
                        centered_prediction_tensor,
                        centered_target_tensor,
                        harmonic_index_list,
                    )
                )

        if centered_loss_list:
            centered_shape_loss = torch.stack(centered_loss_list).mean()
            curve_offset_loss = torch.stack(offset_loss_list).mean()
            curve_amplitude_loss = torch.stack(amplitude_loss_list).mean()
        if harmonic_loss_list:
            sparse_harmonic_shape_loss = torch.stack(harmonic_loss_list).mean()

    # Combine Weighted Terms
    total_loss = (
        point_weight * pointwise_loss
        + centered_weight * centered_shape_loss
        + offset_weight * curve_offset_loss
        + amplitude_weight * curve_amplitude_loss
        + harmonic_weight * sparse_harmonic_shape_loss
    )

    # Report A Declared Python-Int Curve Count When Present, Else The Number Of Segments
    # Any Other Type Under "curve_count" (Including A Tensor) Falls Back To The Segment Count
    if isinstance(declared_curve_count, int):
        reported_curve_count = int(declared_curve_count)
    elif can_compute_curve_terms:
        reported_curve_count = int(count_per_curve_tensor.numel())
    else:
        reported_curve_count = 0

    return {
        "loss": total_loss,
        "pointwise_loss": pointwise_loss,
        "centered_curve_shape_loss": centered_shape_loss,
        "curve_offset_loss": curve_offset_loss,
        "curve_amplitude_loss": curve_amplitude_loss,
        "sparse_harmonic_shape_loss": sparse_harmonic_shape_loss,
        "curve_aware_curve_count": torch.as_tensor(
            reported_curve_count,
            device=pointwise_loss.device,
            dtype=torch.float32,
        ),
    }
[docs]
def compute_sparse_harmonic_shape_loss(
    self,
    angular_position_tensor: torch.Tensor,
    centered_prediction_tensor: torch.Tensor,
    centered_target_tensor: torch.Tensor,
    harmonic_index_list: list[int],
) -> torch.Tensor:
    """Measure the mismatch of sine/cosine projections at selected harmonics.

    All three tensors are flattened to column vectors. For every harmonic
    index the prediction and the target are each projected (by a mean of
    products) onto ``sin(k * angle)`` and ``cos(k * angle)``; the squared
    coefficient differences are averaged over all harmonics and both bases.

    Args:
        angular_position_tensor: Angles in degrees; converted to radians here.
        centered_prediction_tensor: Mean-removed predictions.
        centered_target_tensor: Mean-removed targets.
        harmonic_index_list: Harmonic orders to compare.

    Returns:
        torch.Tensor: Scalar mean squared coefficient difference.
    """
    # Flatten Everything To Column Vectors
    radian_column = torch.deg2rad(angular_position_tensor.reshape(-1, 1))
    prediction_column = centered_prediction_tensor.reshape(-1, 1)
    target_column = centered_target_tensor.reshape(-1, 1)

    def squared_projection_gap(basis_column: torch.Tensor) -> torch.Tensor:
        """Squared difference between prediction and target projection coefficients."""
        prediction_coefficient = torch.mean(prediction_column * basis_column, dim=0)
        target_coefficient = torch.mean(target_column * basis_column, dim=0)
        return torch.mean(torch.square(prediction_coefficient - target_coefficient))

    # Sine Entry First, Then Cosine Entry, For Each Harmonic
    gap_list: list[torch.Tensor] = []
    for harmonic_order in harmonic_index_list:
        phase_column = float(harmonic_order) * radian_column
        sine_column = torch.sin(phase_column)
        cosine_column = torch.cos(phase_column)
        gap_list.append(squared_projection_gap(sine_column))
        gap_list.append(squared_projection_gap(cosine_column))
    return torch.stack(gap_list).mean()
[docs]
def set_normalization_statistics(self, normalization_statistics: NormalizationStatistics) -> None:
    """Copy externally computed statistics into the registered buffers.

    All four shapes are checked before anything is written. Values are cast to
    float, both standard deviations are floored at ``1.0e-8``, and the module
    is then flagged as having initialized statistics.

    Args:
        normalization_statistics: Object exposing ``input_feature_mean``,
            ``input_feature_std``, ``target_mean`` and ``target_std`` tensors.

    Raises:
        AssertionError: If any tensor shape differs from its buffer's shape.
    """
    # (Attribute Name, Label Used In Error Messages, Whether It Is A Std Tensor)
    statistic_layout = (
        ("input_feature_mean", "Input Feature Mean", False),
        ("input_feature_std", "Input Feature Std", True),
        ("target_mean", "Target Mean", False),
        ("target_std", "Target Std", True),
    )

    # Validate Every Shape Before Touching Any Buffer
    for attribute_name, display_name, _ in statistic_layout:
        incoming_shape = getattr(normalization_statistics, attribute_name).shape
        registered_shape = getattr(self, attribute_name).shape
        assert incoming_shape == registered_shape, (
            f"{display_name} shape mismatch | {tuple(incoming_shape)} "
            f"vs {tuple(registered_shape)}"
        )

    # Copy In Place So The Registered Buffers Keep Their Identity
    for attribute_name, _, is_standard_deviation in statistic_layout:
        incoming_tensor = getattr(normalization_statistics, attribute_name).float()
        if is_standard_deviation:
            incoming_tensor = torch.clamp(incoming_tensor, min=1.0e-8)
        getattr(self, attribute_name).copy_(incoming_tensor)

    # Mark Statistics As Ready
    self.normalization_statistics_initialized = True
[docs]
def normalize_target_tensor(self, target_tensor: torch.Tensor) -> torch.Tensor:
    """Standardize targets with the registered target mean and std buffers.

    Raises:
        AssertionError: If normalization statistics have not been loaded.
    """
    # Refuse To Normalize With Placeholder Statistics
    assert self.normalization_statistics_initialized, "Normalization Statistics must be initialized before training"
    centered_target_tensor = target_tensor - self.target_mean
    return centered_target_tensor / self.target_std
[docs]
def denormalize_target_tensor(self, normalized_target_tensor: torch.Tensor) -> torch.Tensor:
    """Undo target standardization by scaling with the std buffer and adding the mean buffer.

    Raises:
        AssertionError: If normalization statistics have not been loaded.
    """
    # Refuse To Denormalize With Placeholder Statistics
    assert self.normalization_statistics_initialized, "Normalization Statistics must be initialized before training"
    rescaled_target_tensor = normalized_target_tensor * self.target_std
    return rescaled_target_tensor + self.target_mean
[docs]
def forward(self, normalized_input_tensor: torch.Tensor) -> torch.Tensor:
    """Apply the wrapped backbone to an already normalized input tensor."""
    # No Normalization Happens Here; The Input Is Passed Through Unchanged
    backbone = self.regression_model
    return backbone(normalized_input_tensor)
[docs]
def forward_regression_model(self, input_tensor: torch.Tensor, normalized_input_tensor: torch.Tensor) -> tuple[torch.Tensor, dict[str, torch.Tensor]]:
    """Run the backbone through the richest entry point it exposes.

    Entry points are tried in this order:

    1. ``compute_auxiliary_output_dictionary(raw, normalized)``: must return
       a dict holding ``prediction_tensor``; every other tensor-valued entry
       is returned as auxiliary output.
    2. ``forward_with_input_context(raw, normalized)``: prediction only.
    3. Calling this module with the normalized input: prediction only.

    Args:
        input_tensor: Raw (un-normalized) input tensor.
        normalized_input_tensor: Normalized input tensor.

    Returns:
        tuple[torch.Tensor, dict[str, torch.Tensor]]: The prediction and a
        possibly empty dictionary of auxiliary tensors.

    Raises:
        AssertionError: If the structured entry point returns a non-dict or
            omits ``prediction_tensor``.
    """
    backbone = self.regression_model

    # Structured Backbones Return A Dictionary Of Named Tensors
    if hasattr(backbone, "compute_auxiliary_output_dictionary"):
        structured_output = backbone.compute_auxiliary_output_dictionary(input_tensor, normalized_input_tensor)
        assert isinstance(structured_output, dict), "Auxiliary Output Dictionary must be a dictionary"
        assert "prediction_tensor" in structured_output, "Auxiliary Output Dictionary must contain prediction_tensor"

        # Keep Only Tensor-Valued Extras; The Prediction Is Returned Separately
        auxiliary_tensors: dict[str, torch.Tensor] = {}
        for output_name, output_value in structured_output.items():
            if output_name == "prediction_tensor" or not isinstance(output_value, torch.Tensor):
                continue
            auxiliary_tensors[output_name] = output_value
        return structured_output["prediction_tensor"], auxiliary_tensors

    # Context-Aware Backbones Receive Both Raw And Normalized Views
    if hasattr(backbone, "forward_with_input_context"):
        return backbone.forward_with_input_context(input_tensor, normalized_input_tensor), {}

    # Plain Backbones Go Through The Module Call
    return self(normalized_input_tensor), {}
[docs]
def compute_batch_outputs(self, batch_dictionary: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
    """Run one batch through normalization, the backbone, losses, and metrics.

    Args:
        batch_dictionary: Batch holding at least ``input_tensor`` and
            ``target_tensor``; it is also forwarded to the curve-aware loss.

    Returns:
        dict[str, torch.Tensor]: Raw and normalized inputs/targets, the raw
        model output, normalized and denormalized predictions, ``mae`` and
        ``rmse`` on denormalized values, then (in this order, later entries
        overriding earlier ones on key clashes) auxiliary backbone outputs,
        the loss dictionary, and the probabilistic diagnostics.
    """
    # Cast Inputs And Targets To Float
    raw_input = batch_dictionary["input_tensor"].float()
    raw_target = batch_dictionary["target_tensor"].float()

    # Move Both Into Normalized Space
    scaled_input = self.normalize_input_tensor(raw_input)
    scaled_target = self.normalize_target_tensor(raw_target)

    # Backbone Forward Plus Reduction To A Deterministic Prediction
    model_output, auxiliary_outputs = self.forward_regression_model(raw_input, scaled_input)
    scaled_prediction = self.extract_deterministic_prediction_tensor(model_output)

    # Error Metrics On Denormalized Predictions
    physical_prediction = self.denormalize_target_tensor(scaled_prediction)
    physical_error = physical_prediction - raw_target
    mean_absolute_error = torch.mean(torch.abs(physical_error))
    root_mean_squared_error = torch.sqrt(torch.mean(torch.square(physical_error)))

    # The Pointwise Loss Sees The Raw Model Output; Curve Terms See The Deterministic Prediction
    pointwise_loss = self.compute_pointwise_prediction_loss(model_output, scaled_target)
    loss_terms = self.compute_curve_aware_loss_dictionary(
        batch_dictionary=batch_dictionary,
        batch_output_dictionary={
            "prediction_tensor": scaled_prediction,
            "target_tensor": scaled_target,
        },
        pointwise_loss=pointwise_loss,
    )
    uncertainty_metrics = self.compute_probabilistic_metric_dictionary(model_output, scaled_target)

    # Later Groups Override Earlier Ones On Key Clashes
    return {
        "input_tensor": raw_input,
        "target_tensor": raw_target,
        "normalized_input_tensor": scaled_input,
        "normalized_target_tensor": scaled_target,
        "normalized_model_output_tensor": model_output,
        "normalized_prediction_tensor": scaled_prediction,
        "prediction_tensor": physical_prediction,
        "mae": mean_absolute_error,
        "rmse": root_mean_squared_error,
        **auxiliary_outputs,
        **loss_terms,
        **uncertainty_metrics,
    }
[docs]
def compute_loss(self, batch_dictionary: dict[str, torch.Tensor], log_prefix: str) -> torch.Tensor:
    """Compute batch outputs, log epoch-level metrics, and return the loss.

    Every metric is logged as ``{log_prefix}_{name}`` with ``on_step=False``
    and ``on_epoch=True``. The total loss always appears on the progress bar;
    the MAE appears there for every prefix except ``train``.

    Args:
        batch_dictionary: Batch forwarded to ``compute_batch_outputs``.
        log_prefix: Prefix for metric names, e.g. ``train``, ``val`` or ``test``.

    Returns:
        torch.Tensor: The ``loss`` entry of the batch outputs, i.e. the
        weighted total assembled by the curve-aware loss routine.
    """
    outputs = self.compute_batch_outputs(batch_dictionary)
    input_tensor = outputs["input_tensor"]
    loss = outputs["loss"]
    mae = outputs["mae"]
    rmse = outputs["rmse"]
    batch_size = int(input_tensor.shape[0])

    def log_epoch_metric(metric_suffix: str, metric_value: torch.Tensor, show_on_progress_bar: bool = False) -> None:
        """Log one epoch-aggregated metric under the split prefix."""
        self.log(
            f"{log_prefix}_{metric_suffix}",
            metric_value,
            on_step=False,
            on_epoch=True,
            prog_bar=show_on_progress_bar,
            batch_size=batch_size,
        )

    # Headline Metrics
    log_epoch_metric("loss", loss, True)
    log_epoch_metric("mae", mae, (log_prefix != "train"))
    log_epoch_metric("rmse", rmse)

    # Individual Loss Terms
    for loss_term_name in (
        "pointwise_loss",
        "centered_curve_shape_loss",
        "curve_offset_loss",
        "curve_amplitude_loss",
        "sparse_harmonic_shape_loss",
    ):
        log_epoch_metric(loss_term_name, outputs[loss_term_name])

    # Probabilistic Diagnostics Are Logged Only When Present As Tensors
    for diagnostic_name in (
        "interval_coverage",
        "interval_width",
        "quantile_crossing_rate",
        "mean_sigma",
        "mixture_weight_entropy",
        "mixture_effective_components",
        "mixture_mean_sigma",
        "mixture_component_separation",
    ):
        diagnostic_value = outputs.get(diagnostic_name)
        if isinstance(diagnostic_value, torch.Tensor):
            log_epoch_metric(diagnostic_name, diagnostic_value)

    # Optional Branch Predictions (Structured, Then Base) Get Their Own MAE And RMSE
    for branch_label in ("structured", "base"):
        branch_prediction = outputs.get(f"{branch_label}_prediction_tensor")
        if not isinstance(branch_prediction, torch.Tensor):
            continue
        branch_physical_prediction = self.denormalize_target_tensor(branch_prediction)
        branch_mae = torch.mean(torch.abs(branch_physical_prediction - outputs["target_tensor"]))
        branch_rmse = torch.sqrt(torch.mean(torch.square(branch_physical_prediction - outputs["target_tensor"])))
        log_epoch_metric(f"{branch_label}_mae", branch_mae)
        log_epoch_metric(f"{branch_label}_rmse", branch_rmse)

    # Residual Offsets Are Rescaled By The Target Std Only (No Mean Shift)
    residual_offset_prediction = outputs.get("residual_offset_prediction_tensor")
    if isinstance(residual_offset_prediction, torch.Tensor):
        residual_offset_physical = residual_offset_prediction * self.target_std
        log_epoch_metric("residual_offset_mean_abs", torch.mean(torch.abs(residual_offset_physical)))

    return loss
[docs]
def training_step(self, batch_dictionary: dict[str, torch.Tensor], batch_idx: int) -> torch.Tensor:
    """Return the loss for one training batch, logging metrics under the ``train`` prefix."""
    # The Batch Index Is Not Used; All Work Is Delegated To compute_loss
    split_prefix = "train"
    return self.compute_loss(batch_dictionary, split_prefix)
[docs]
def validation_step(self, batch_dictionary: dict[str, torch.Tensor], batch_idx: int) -> torch.Tensor:
    """Return the loss for one validation batch, logging metrics under the ``val`` prefix."""
    # The Batch Index Is Not Used; All Work Is Delegated To compute_loss
    split_prefix = "val"
    return self.compute_loss(batch_dictionary, split_prefix)
[docs]
def test_step(self, batch_dictionary: dict[str, torch.Tensor], batch_idx: int) -> torch.Tensor:
"""Run one Lightning test step and return the loss tensor."""
# Compute Loss And Metrics For Test Step
return self.compute_loss(batch_dictionary, "test")