Transmission Error Regression Module

This page documents the Lightning regression wrapper that standardizes normalization, loss computation, metric logging, and optimizer configuration for transmission error (TE) models.

Lightning regression module for normalized TE prediction workflows.

class scripts.training.transmission_error_regression_module.TransmissionErrorRegressionModule(regression_model, input_feature_dim, target_feature_dim, learning_rate=1.0e-3, weight_decay=1.0e-4, normalization_statistics=None, loss_configuration=None)

LightningModule that wraps TE backbones, normalization, and metrics.

Parameters:
  • regression_model (nn.Module)

  • input_feature_dim (int)

  • target_feature_dim (int)

  • learning_rate (float)

  • weight_decay (float)

  • normalization_statistics (NormalizationStatistics | None)

  • loss_configuration (dict | None)

__init__(regression_model, input_feature_dim, target_feature_dim, learning_rate=1.0e-3, weight_decay=1.0e-4, normalization_statistics=None, loss_configuration=None)

Initialize the TE regression LightningModule.

Parameters:
  • regression_model (Module) – Backbone model operating on normalized inputs.

  • input_feature_dim (int) – Number of model input features.

  • target_feature_dim (int) – Number of regression targets.

  • learning_rate (float) – AdamW learning rate.

  • weight_decay (float) – AdamW weight decay.

  • normalization_statistics (NormalizationStatistics | None) – Optional normalization tensors loaded at construction time.

  • loss_configuration (dict | None) – Optional composite loss configuration used by curve-aware campaign branches.

Return type:

None

compute_pointwise_prediction_loss(prediction_tensor, target_tensor)

Compute the configured normalized-space pointwise regression loss.

Parameters:
  • prediction_tensor (Tensor)

  • target_tensor (Tensor)

Return type:

Tensor
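The page does not say which pointwise losses loss_configuration can select. As a minimal plain-Python sketch, assuming the choice is among mean squared error, mean absolute error, and Huber loss (the loss_name values and huber_delta below are hypothetical, not the module's configuration keys), the reduction over normalized residuals looks like this:

```python
def pointwise_loss(predictions, targets, loss_name="mse", huber_delta=1.0):
    """Mean pointwise loss over paired values in normalized target space.

    The loss names and huber_delta are hypothetical stand-ins for whatever
    loss_configuration actually exposes.
    """
    if not predictions or len(predictions) != len(targets):
        raise ValueError("predictions and targets must be non-empty and equal length")
    residuals = [p - t for p, t in zip(predictions, targets)]
    if loss_name == "mse":
        terms = [r * r for r in residuals]
    elif loss_name == "mae":
        terms = [abs(r) for r in residuals]
    elif loss_name == "huber":
        # Quadratic for small residuals, linear beyond huber_delta.
        terms = [
            0.5 * r * r if abs(r) <= huber_delta else huber_delta * (abs(r) - 0.5 * huber_delta)
            for r in residuals
        ]
    else:
        raise ValueError(f"unknown loss_name: {loss_name!r}")
    return sum(terms) / len(terms)


print(pointwise_loss([0.0, 2.0], [0.0, 0.0], "huber"))  # 0.75
```

Huber loss is a common compromise here: it behaves like MSE near zero but grows only linearly for outlier residuals.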

compute_quantile_pinball_loss(prediction_tensor, target_tensor)

Compute multi-quantile pinball loss in normalized target space.

Parameters:
  • prediction_tensor (Tensor)

  • target_tensor (Tensor)

Return type:

Tensor
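For quantile level q and residual u = y - y_hat, the pinball loss is max(q*u, (q - 1)*u), so under-prediction is weighted by q and over-prediction by 1 - q. The plain-Python sketch below averages it over samples and quantile levels. How the real module lays out quantile channels inside prediction_tensor is not documented here, so the nested-list layout is an assumption.

```python
def pinball_loss(quantile_predictions, targets, quantile_levels):
    """Mean multi-quantile pinball loss.

    quantile_predictions[i][k] is the prediction for sample i at quantile_levels[k]
    (an assumed layout, not the module's documented tensor shape).
    """
    total = 0.0
    count = 0
    for row, target in zip(quantile_predictions, targets):
        for prediction, level in zip(row, quantile_levels):
            residual = target - prediction
            # Under-prediction (residual > 0) is weighted by level, over-prediction by 1 - level.
            total += max(level * residual, (level - 1.0) * residual)
            count += 1
    if count == 0:
        raise ValueError("no predictions to score")
    return total / count


print(pinball_loss([[0.0]], [2.0], [0.5]))  # 1.0
```

With a single level of 0.5 the pinball loss is half the absolute error, which is why the median quantile behaves like an MAE target.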

compute_gaussian_negative_log_likelihood_loss(prediction_tensor, target_tensor)

Compute guarded Gaussian NLL in normalized target space.

Parameters:
  • prediction_tensor (Tensor)

  • target_tensor (Tensor)

Return type:

Tensor
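The page does not define the guard. A common choice, assumed in this sketch, is to clamp the predicted log standard deviation before exponentiating, so the variance can neither collapse to zero nor overflow. The per-sample negative log-likelihood is then log(sigma) + (y - mu)^2 / (2 sigma^2) + 0.5 log(2 pi). The clamp bounds below are hypothetical.

```python
import math


def gaussian_nll(mean, log_sigma, target, log_sigma_min=-7.0, log_sigma_max=7.0, include_constant=True):
    """Per-sample Gaussian NLL with a clamped log standard deviation.

    log_sigma_min and log_sigma_max are hypothetical guard bounds.
    """
    guarded_log_sigma = min(max(log_sigma, log_sigma_min), log_sigma_max)
    sigma = math.exp(guarded_log_sigma)
    standardized_residual = (target - mean) / sigma
    nll = guarded_log_sigma + 0.5 * standardized_residual ** 2
    if include_constant:
        # Constant offset; it shifts the loss value but not its gradients.
        nll += 0.5 * math.log(2.0 * math.pi)
    return nll
```

The clamp matters most early in training, when an unconstrained log-sigma can run to large negative values and make the squared-residual term explode.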

split_mixture_density_output_tensor(prediction_tensor)

Split MDN output into logits, means, and guarded log-sigmas.

Parameters:

prediction_tensor (Tensor)

Return type:

tuple[Tensor, Tensor, Tensor]
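The channel ordering of the MDN head is not documented on this page. Assuming, for illustration only, that each scalar target receives a flat vector laid out as K logits, then K means, then K log-sigmas, the split and guard could be sketched as follows (the clamp bounds are hypothetical):

```python
def split_mixture_output(flat_output, component_count, log_sigma_min=-7.0, log_sigma_max=7.0):
    """Split a flat [logits | means | log_sigmas] vector for one scalar target.

    The block ordering and the clamp bounds are assumptions of this sketch.
    """
    if len(flat_output) != 3 * component_count:
        raise ValueError("expected 3 * component_count values")
    logits = list(flat_output[:component_count])
    means = list(flat_output[component_count:2 * component_count])
    # Guard: clamp log-sigmas so exp() stays in a numerically safe range.
    log_sigmas = [
        min(max(value, log_sigma_min), log_sigma_max)
        for value in flat_output[2 * component_count:]
    ]
    return logits, means, log_sigmas


print(split_mixture_output([0.0, 1.0, -1.0, 2.0, -20.0, 20.0], 2))
# ([0.0, 1.0], [-1.0, 2.0], [-7.0, 7.0])
```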

compute_mixture_density_negative_log_likelihood_loss(prediction_tensor, target_tensor)

Compute stable Gaussian-mixture NLL in normalized target space.

Parameters:
  • prediction_tensor (Tensor)

  • target_tensor (Tensor)

Return type:

Tensor
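A stable mixture NLL is usually computed in log space: log p(y) = logsumexp_k(log_softmax(logits)_k + log N(y; mu_k, sigma_k)), with the maximum subtracted inside each log-sum-exp so large logits do not overflow. The plain-Python sketch below does this for one scalar target; it illustrates the technique and is not the module's torch implementation.

```python
import math


def log_sum_exp(values):
    """Numerically stable log(sum(exp(v))) via max subtraction."""
    peak = max(values)
    return peak + math.log(sum(math.exp(v - peak) for v in values))


def mixture_nll(logits, means, log_sigmas, target):
    """Negative log-likelihood of one scalar target under a 1-D Gaussian mixture."""
    log_normalizer = log_sum_exp(logits)
    component_terms = []
    for logit, mean, log_sigma in zip(logits, means, log_sigmas):
        log_weight = logit - log_normalizer  # log-softmax of the mixture logits
        z = (target - mean) / math.exp(log_sigma)
        log_density = -log_sigma - 0.5 * z * z - 0.5 * math.log(2.0 * math.pi)
        component_terms.append(log_weight + log_density)
    return -log_sum_exp(component_terms)


# Logits of 1000 would overflow a naive exp(); the log-space version stays finite.
print(mixture_nll([1000.0, 0.0], [0.0, 5.0], [0.0, 0.0], 0.0))
```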

compute_mixture_expectation_tensor(model_output_tensor)

Compute deterministic MDN playback as the mixture expectation.

Parameters:

model_output_tensor (Tensor)

Return type:

Tensor
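For a Gaussian mixture with weights pi = softmax(logits) and component means mu_k, the expectation is sum_k pi_k * mu_k. A plain-Python sketch for one scalar target (variable names are illustrative):

```python
import math


def mixture_expectation(logits, means):
    """Mixture mean: sum_k softmax(logits)_k * means_k, computed stably."""
    peak = max(logits)
    unnormalized_weights = [math.exp(logit - peak) for logit in logits]
    total_weight = sum(unnormalized_weights)
    return sum(w * m for w, m in zip(unnormalized_weights, means)) / total_weight


print(mixture_expectation([0.0, 0.0], [-1.0, 3.0]))  # 1.0
```

For a well-separated bimodal mixture the expectation can land between the modes, in a region the model itself considers unlikely; that trade-off comes with using the mean as the deterministic playback value.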

extract_deterministic_prediction_tensor(model_output_tensor)

Select the deterministic channel used for MAE/RMSE and TE curve-verification playback.

Parameters:

model_output_tensor (Tensor)

Return type:

Tensor

compute_probabilistic_metric_dictionary(model_output_tensor, normalized_target_tensor)

Compute optional uncertainty diagnostics from raw model outputs.

Parameters:
  • model_output_tensor (Tensor)

  • normalized_target_tensor (Tensor)

Return type:

dict[str, Tensor]
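The page does not list which diagnostics this method produces. One common uncertainty diagnostic, shown here purely as an illustration, is the empirical coverage of a prediction interval (for example, between a lower and an upper predicted quantile) together with its mean width. The metric keys below are hypothetical.

```python
def interval_diagnostics(lower_bounds, upper_bounds, targets):
    """Empirical coverage and mean width of per-sample prediction intervals.

    The returned key names are hypothetical, not the module's metric names.
    """
    if not targets:
        raise ValueError("targets must be non-empty")
    inside = sum(
        1 for low, high, y in zip(lower_bounds, upper_bounds, targets) if low <= y <= high
    )
    widths = [high - low for low, high in zip(lower_bounds, upper_bounds)]
    return {
        "interval_coverage": inside / len(targets),
        "interval_mean_width": sum(widths) / len(widths),
    }
```

For an interval built from the 0.1 and 0.9 quantiles, a well-calibrated model should cover roughly 80% of targets; reporting the width alongside coverage exposes models that reach coverage by simply widening every interval.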

compute_curve_aware_loss_dictionary(batch_dictionary, batch_output_dictionary, pointwise_loss)

Compute optional curve-aware loss terms from collated curve groups.

Parameters:
  • batch_dictionary (dict[str, Tensor])

  • batch_output_dictionary (dict[str, Tensor])

  • pointwise_loss (Tensor)

Return type:

dict[str, Tensor]

compute_sparse_harmonic_shape_loss(angular_position_tensor, centered_prediction_tensor, centered_target_tensor, harmonic_index_list)

Compare centered prediction and truth on sparse sine/cosine terms.

Parameters:
  • angular_position_tensor (Tensor)

  • centered_prediction_tensor (Tensor)

  • centered_target_tensor (Tensor)

  • harmonic_index_list (list[int])

Return type:

Tensor
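The exact projection used by the module is not documented. A minimal sketch, assuming uniformly spaced samples over one full revolution so that the discrete Fourier projections (2/N) sum x*cos(h*theta) and (2/N) sum x*sin(h*theta) recover each harmonic's coefficients, compares only the listed harmonics and averages the squared coefficient differences:

```python
import math


def harmonic_coefficients(angles, values, harmonic_index_list):
    """Cosine/sine coefficients for each listed harmonic (uniform full-revolution sampling assumed)."""
    sample_count = len(angles)
    coefficients = []
    for harmonic in harmonic_index_list:
        cos_coefficient = 2.0 / sample_count * sum(v * math.cos(harmonic * a) for a, v in zip(angles, values))
        sin_coefficient = 2.0 / sample_count * sum(v * math.sin(harmonic * a) for a, v in zip(angles, values))
        coefficients.extend([cos_coefficient, sin_coefficient])
    return coefficients


def sparse_harmonic_shape_loss(angles, centered_prediction, centered_target, harmonic_index_list):
    """Mean squared difference between prediction and target on the listed harmonics only."""
    predicted = harmonic_coefficients(angles, centered_prediction, harmonic_index_list)
    expected = harmonic_coefficients(angles, centered_target, harmonic_index_list)
    return sum((p - e) ** 2 for p, e in zip(predicted, expected)) / len(predicted)
```

Because the comparison is restricted to harmonic_index_list, content at other orders (for example, an unlisted mesh harmonic) does not contribute to this term.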

set_normalization_statistics(normalization_statistics)

Load normalization tensors into the module buffers.

Parameters:

normalization_statistics (NormalizationStatistics) – Input and target statistics computed from the training split.

Return type:

None

normalize_input_tensor(input_tensor)

Normalize model inputs with the registered training statistics.

Parameters:

input_tensor (Tensor)

Return type:

Tensor

normalize_target_tensor(target_tensor)

Normalize regression targets with the registered statistics.

Parameters:

target_tensor (Tensor)

Return type:

Tensor

denormalize_target_tensor(normalized_target_tensor)

Map normalized target predictions back to physical TE units.

Parameters:

normalized_target_tensor (Tensor)

Return type:

Tensor
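Assuming the statistics are per-feature means and standard deviations (the field layout of NormalizationStatistics is not shown on this page), normalization is a z-score and denormalization is its inverse. The small eps floor, which keeps a constant feature from causing a division by zero, is an assumption of this sketch.

```python
def normalize(values, means, stds, eps=1e-8):
    """Per-feature z-score: (x - mean) / std, with std floored at eps."""
    return [(v - m) / max(s, eps) for v, m, s in zip(values, means, stds)]


def denormalize(normalized_values, means, stds, eps=1e-8):
    """Inverse of normalize: z * std + mean, using the same std floor."""
    return [z * max(s, eps) + m for z, m, s in zip(normalized_values, means, stds)]


print(normalize([3.0, 10.0], [1.0, 10.0], [2.0, 5.0]))  # [1.0, 0.0]
```

Both directions must use the same statistics computed from the training split; mixing statistics between splits silently shifts every reported TE value.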

forward(normalized_input_tensor)

Run the backbone on normalized inputs only.

Parameters:

normalized_input_tensor (Tensor)

Return type:

Tensor

forward_regression_model(input_tensor, normalized_input_tensor)

Run the backbone while supporting structured auxiliary outputs.

Parameters:
  • input_tensor (Tensor) – Raw input tensor before normalization-aware routing.

  • normalized_input_tensor (Tensor) – Normalized model input tensor.

Returns:

Prediction tensor in normalized space plus any auxiliary structured outputs emitted by the backbone.

Return type:

tuple[torch.Tensor, dict[str, torch.Tensor]]
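The page does not specify how a backbone hands back auxiliary outputs. One plausible contract, assumed here, is that a backbone returns either a bare prediction or a (prediction, auxiliary_dictionary) pair, and the wrapper normalizes both cases to the documented (prediction, dict) return type. unpack_backbone_output is a hypothetical helper, not part of the module.

```python
def unpack_backbone_output(backbone_output):
    """Return (prediction, auxiliary_outputs) for either supported backbone output shape.

    Hypothetical contract: a backbone returns a bare prediction or a
    (prediction, auxiliary_dictionary) tuple.
    """
    if isinstance(backbone_output, tuple):
        if len(backbone_output) != 2 or not isinstance(backbone_output[1], dict):
            raise TypeError("expected a (prediction, auxiliary_dictionary) tuple")
        prediction, auxiliary_outputs = backbone_output
        return prediction, dict(auxiliary_outputs)
    # Plain backbones get an empty auxiliary dictionary so callers never branch on type.
    return backbone_output, {}
```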

compute_batch_outputs(batch_dictionary)

Compute normalized loss terms and denormalized TE metrics.

Parameters:

batch_dictionary (dict[str, Tensor]) – Point-level batch emitted by the datamodule.

Returns:

Batch outputs including normalized predictions, denormalized predictions, and metric tensors.

Return type:

dict[str, torch.Tensor]
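Metrics are reported in physical TE units, so predictions and targets are denormalized before the errors are taken. A sketch assuming a single target with z-score statistics (target_mean, target_std, and the "mae"/"rmse" keys are hypothetical names) also makes visible that the mean cancels in each residual, so physical-unit MAE and RMSE are the normalized ones scaled by the standard deviation:

```python
import math


def physical_unit_metrics(normalized_predictions, normalized_targets, target_mean, target_std):
    """MAE and RMSE after mapping normalized values back to physical TE units."""
    if not normalized_targets or len(normalized_predictions) != len(normalized_targets):
        raise ValueError("predictions and targets must be non-empty and equal length")
    predictions = [z * target_std + target_mean for z in normalized_predictions]
    targets = [z * target_std + target_mean for z in normalized_targets]
    # target_mean cancels in each residual; only target_std rescales the error.
    errors = [p - t for p, t in zip(predictions, targets)]
    mae = sum(abs(e) for e in errors) / len(errors)
    rmse = math.sqrt(sum(e * e for e in errors) / len(errors))
    return {"mae": mae, "rmse": rmse}


print(physical_unit_metrics([1.0, -1.0], [0.0, 0.0], 5.0, 2.0))  # {'mae': 2.0, 'rmse': 2.0}
```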

compute_loss(batch_dictionary, log_prefix)

Compute and log one split-specific loss bundle.

Parameters:
  • batch_dictionary (dict[str, Tensor]) – Point-level batch emitted by the datamodule.

  • log_prefix (str) – Prefix used for Lightning metric names such as train, val, or test.

Returns:

Scalar normalized-space MSE loss.

Return type:

torch.Tensor

training_step(batch_dictionary, batch_idx)

Run one Lightning training step and return the loss tensor.

Parameters:
  • batch_dictionary (dict[str, Tensor])

  • batch_idx (int)

Return type:

Tensor

validation_step(batch_dictionary, batch_idx)

Run one Lightning validation step and return the loss tensor.

Parameters:
  • batch_dictionary (dict[str, Tensor])

  • batch_idx (int)

Return type:

Tensor

test_step(batch_dictionary, batch_idx)

Run one Lightning test step and return the loss tensor.

Parameters:
  • batch_dictionary (dict[str, Tensor])

  • batch_idx (int)

Return type:

Tensor

configure_optimizers()

Configure the AdamW optimizer for TE regression training.