Transmission Error DataModule

This page documents the Lightning data module and the batch-collation helpers used by the transmission error (TE) regression training workflow.

class scripts.training.transmission_error_datamodule.NormalizationStatistics(input_feature_mean, input_feature_std, target_mean, target_std)[source]

Feature and target normalization tensors for TE regression.

Parameters:
  • input_feature_mean (Tensor)

  • input_feature_std (Tensor)

  • target_mean (Tensor)

  • target_std (Tensor)
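To illustrate how these four tensors are typically consumed, the sketch below applies standard z-score normalization to a point-level batch. The `NormalizationStatistics` tuple here is an illustrative stand-in mirroring the documented fields, and `normalize_points` is a hypothetical helper, not part of the module:

```python
from typing import NamedTuple, Tuple

import torch


class NormalizationStatistics(NamedTuple):
    """Illustrative stand-in mirroring the documented fields."""

    input_feature_mean: torch.Tensor
    input_feature_std: torch.Tensor
    target_mean: torch.Tensor
    target_std: torch.Tensor


def normalize_points(
    inputs: torch.Tensor,
    targets: torch.Tensor,
    stats: NormalizationStatistics,
) -> Tuple[torch.Tensor, torch.Tensor]:
    # Z-score normalization; the std tensors are assumed strictly positive
    # and to broadcast against the trailing feature dimension.
    norm_inputs = (inputs - stats.input_feature_mean) / stats.input_feature_std
    norm_targets = (targets - stats.target_mean) / stats.target_std
    return norm_inputs, norm_targets


stats = NormalizationStatistics(
    input_feature_mean=torch.tensor([1.0, 2.0]),
    input_feature_std=torch.tensor([2.0, 4.0]),
    target_mean=torch.tensor([0.5]),
    target_std=torch.tensor([0.25]),
)
x = torch.tensor([[3.0, 6.0]])
y = torch.tensor([[1.0]])
nx, ny = normalize_points(x, y, stats)
```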

class scripts.training.transmission_error_datamodule.DatasetSplitSummary(train_curve_count, validation_curve_count, test_curve_count)[source]

Curve counts for the train, validation, and test splits.

Parameters:
  • train_curve_count (int)

  • validation_curve_count (int)

  • test_curve_count (int)

scripts.training.transmission_error_datamodule.move_batch_tensor_collection_to_device(batch_value, device, use_non_blocking_transfer=False)[source]

Recursively move a nested batch structure to the requested device.

Parameters:
  • batch_value (Any) – Tensor, list, tuple, dictionary, or scalar batch value.

  • device (device) – Target device for tensor transfer.

  • use_non_blocking_transfer (bool) – Whether CUDA transfers should request non-blocking behavior when possible.

Returns:

Batch structure mirrored on the target device.

Return type:

Any
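A minimal sketch of the recursion this helper describes: tensors are moved, containers are rebuilt with the same type, and scalar or metadata values pass through unchanged. This is an independent illustration of the technique, not the module's actual implementation (plain tuples are assumed; named tuples would need extra handling):

```python
from typing import Any

import torch


def move_batch_to_device(
    batch_value: Any,
    device: torch.device,
    use_non_blocking_transfer: bool = False,
) -> Any:
    # Tensors are transferred directly; non_blocking only has an effect for
    # pinned-host-to-CUDA copies, so requesting it elsewhere is harmless.
    if isinstance(batch_value, torch.Tensor):
        return batch_value.to(device, non_blocking=use_non_blocking_transfer)
    # Dictionaries and sequences are rebuilt with each element recursed.
    if isinstance(batch_value, dict):
        return {
            key: move_batch_to_device(value, device, use_non_blocking_transfer)
            for key, value in batch_value.items()
        }
    if isinstance(batch_value, (list, tuple)):
        moved = [
            move_batch_to_device(item, device, use_non_blocking_transfer)
            for item in batch_value
        ]
        return type(batch_value)(moved)
    # Scalars and other metadata stay as-is.
    return batch_value


batch = {"inputs": torch.zeros(4, 2), "meta": {"curve_ids": [0, 1]}, "count": 4}
moved = move_batch_to_device(batch, torch.device("cpu"))
```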

scripts.training.transmission_error_datamodule.extract_point_tensor_from_curve_sample(curve_sample_dictionary, point_stride=1, maximum_points_per_curve=None)[source]

Convert one curve sample into a point-sampled tensor dictionary.

Parameters:
  • curve_sample_dictionary (dict[str, Any]) – Dataset sample containing full-curve tensors.

  • point_stride (int) – Step used to subsample curve points.

  • maximum_points_per_curve (int | None) – Optional hard cap on sampled points per curve.

Returns:

Point-level tensors for model input, target, and angular position.

Return type:

dict[str, torch.Tensor]

scripts.training.transmission_error_datamodule.collate_transmission_error_points(batch_dictionary_list, point_stride=1, maximum_points_per_curve=None, shuffle_points=True)[source]

Collate curve samples into one point-level training batch.

Parameters:
  • batch_dictionary_list (list[dict[str, Any]]) – Curve-level samples produced by the dataset.

  • point_stride (int) – Step used to subsample points from each curve.

  • maximum_points_per_curve (int | None) – Optional cap on sampled points per curve.

  • shuffle_points (bool) – Whether to shuffle the concatenated point batch.

Returns:

Batch dictionary containing concatenated tensors and lightweight curve-level metadata.

Return type:

dict[str, Any]
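The overall collation pattern, sketched independently of the module: subsample each curve, concatenate along the point axis, and shuffle with a single shared permutation so rows stay aligned across keys. The `points_per_curve` metadata key is hypothetical, standing in for the "lightweight curve-level metadata" mentioned above:

```python
import torch


def collate_points(
    batch_dictionary_list: list,
    point_stride: int = 1,
    shuffle_points: bool = True,
) -> dict:
    # Subsample each curve, then concatenate every key along the point axis
    # so the model sees one flat point batch instead of per-curve tensors.
    per_curve = [
        {key: value[::point_stride] for key, value in sample.items()}
        for sample in batch_dictionary_list
    ]
    collated = {
        key: torch.cat([sample[key] for sample in per_curve], dim=0)
        for key in per_curve[0]
    }
    if shuffle_points:
        # One shared permutation keeps input/target rows aligned across keys.
        permutation = torch.randperm(next(iter(collated.values())).shape[0])
        collated = {key: value[permutation] for key, value in collated.items()}
    # Curve-level metadata: how many points each curve contributed.
    collated["points_per_curve"] = [
        next(iter(sample.values())).shape[0] for sample in per_curve
    ]
    return collated


curves = [
    {"model_input": torch.randn(6, 2), "target": torch.randn(6, 1)},
    {"model_input": torch.randn(8, 2), "target": torch.randn(8, 1)},
]
# Stride 2 keeps 3 points from the first curve and 4 from the second.
batch = collate_points(curves, point_stride=2)
```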

class scripts.training.transmission_error_datamodule.TransmissionErrorDataModule(dataset_config_path, curve_batch_size=2, point_stride=20, maximum_points_per_curve=None, num_workers=0, pin_memory=False, use_non_blocking_transfer=False)[source]

LightningDataModule for TE curve splits, sampling, and normalization.

Parameters:
  • dataset_config_path (str | Path)

  • curve_batch_size (int)

  • point_stride (int)

  • maximum_points_per_curve (int | None)

  • num_workers (int)

  • pin_memory (bool)

  • use_non_blocking_transfer (bool)

__init__(dataset_config_path, curve_batch_size=2, point_stride=20, maximum_points_per_curve=None, num_workers=0, pin_memory=False, use_non_blocking_transfer=False)[source]

Initialize the reusable TE training datamodule.

Parameters:
  • dataset_config_path (str | Path) – Dataset YAML configuration path.

  • curve_batch_size (int) – Number of curves loaded per dataloader batch.

  • point_stride (int) – Subsampling stride applied inside each curve.

  • maximum_points_per_curve (int | None) – Optional cap on sampled points per curve.

  • num_workers (int) – PyTorch dataloader worker count.

  • pin_memory (bool) – Whether dataloaders should pin host memory.

  • use_non_blocking_transfer (bool) – Whether device transfer should request non-blocking CUDA copies when possible.

Return type:

None

setup(stage=None)[source]

Create dataset splits and compute normalization statistics.

Parameters:

stage (str | None) – Lightning stage hint. The current implementation uses the same prepared splits for fit, validation, and test access.

Return type:

None

compute_normalization_statistics(curve_dataset)[source]

Compute point-level normalization statistics from the train split.

Parameters:

curve_dataset (TransmissionErrorCurveDataset) – Dataset used to accumulate sampled point statistics.

Returns:

Mean and standard deviation tensors for model inputs and targets.

Return type:

NormalizationStatistics
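A minimal sketch of accumulating these statistics from sampled point batches. It concatenates in one pass for clarity (the real module may stream statistics instead), and the key names and the epsilon guard against zero-variance features are assumptions:

```python
import torch


def compute_point_statistics(point_batches: list) -> dict:
    # Gather every sampled point from the train split, then reduce over the
    # point axis to get per-feature mean and standard deviation.
    inputs = torch.cat([batch["model_input"] for batch in point_batches], dim=0)
    targets = torch.cat([batch["target"] for batch in point_batches], dim=0)
    eps = 1e-8  # guard against zero-variance features dividing by zero later
    return {
        "input_feature_mean": inputs.mean(dim=0),
        "input_feature_std": inputs.std(dim=0).clamp_min(eps),
        "target_mean": targets.mean(dim=0),
        "target_std": targets.std(dim=0).clamp_min(eps),
    }


batches = [
    {"model_input": torch.tensor([[0.0, 0.0]]), "target": torch.tensor([[1.0]])},
    {"model_input": torch.tensor([[2.0, 4.0]]), "target": torch.tensor([[3.0]])},
]
stats = compute_point_statistics(batches)
```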

get_input_feature_dim()[source]

Return the resolved input feature dimension after setup.

Return type:

int

get_target_feature_dim()[source]

Return the resolved target feature dimension after setup.

Return type:

int

get_normalization_statistics()[source]

Return the cached normalization statistics after setup.

Return type:

NormalizationStatistics

get_dataset_split_summary()[source]

Return the current split sizes in number of curves.

Return type:

DatasetSplitSummary

train_dataloader()[source]

Build the training dataloader with point-level collation.

Return type:

DataLoader

val_dataloader()[source]

Build the validation dataloader with deterministic point ordering.

Return type:

DataLoader

test_dataloader()[source]

Build the test dataloader with deterministic point ordering.

Return type:

DataLoader

transfer_batch_to_device(batch, device, dataloader_idx)[source]

Move a collated batch to the target accelerator device.

Parameters:
  • batch (Any) – Batch emitted by one of the TE dataloaders.

  • device (device) – Target Lightning device.

  • dataloader_idx (int) – Dataloader index supplied by Lightning.

Returns:

Batch moved to the requested device.

Return type:

Any