Transmission Error DataModule
This page documents the Lightning data module and batch-collation helpers used by the transmission error (TE) regression training workflow.
- class scripts.training.transmission_error_datamodule.NormalizationStatistics(input_feature_mean, input_feature_std, target_mean, target_std)[source]
Feature and target normalization tensors for TE regression.
- Parameters:
input_feature_mean (Tensor)
input_feature_std (Tensor)
target_mean (Tensor)
target_std (Tensor)
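The four tensors bundle the usual z-score normalization parameters. A minimal pure-Python sketch of how such a container is typically applied (the class name mirrors the one documented above, but the fields here hold float lists instead of tensors, and `normalize_features` is a hypothetical helper for illustration):

```python
from dataclasses import dataclass


@dataclass
class NormalizationStatistics:
    # Mirrors the documented fields; the real class stores torch.Tensor values.
    input_feature_mean: list[float]
    input_feature_std: list[float]
    target_mean: list[float]
    target_std: list[float]


def normalize_features(features: list[float], stats: NormalizationStatistics) -> list[float]:
    """Apply the standard z-score transform feature-wise."""
    return [
        (value - mean) / std
        for value, mean, std in zip(features, stats.input_feature_mean, stats.input_feature_std)
    ]


stats = NormalizationStatistics(
    input_feature_mean=[1.0, 2.0],
    input_feature_std=[2.0, 4.0],
    target_mean=[0.0],
    target_std=[1.0],
)
print(normalize_features([3.0, 6.0], stats))  # [1.0, 1.0]
```

The target statistics are applied the same way to the regression targets, and inverted at prediction time to recover physical units.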
- class scripts.training.transmission_error_datamodule.DatasetSplitSummary(train_curve_count, validation_curve_count, test_curve_count)[source]
Curve counts for the train, validation, and test splits.
- Parameters:
train_curve_count (int)
validation_curve_count (int)
test_curve_count (int)
- scripts.training.transmission_error_datamodule.move_batch_tensor_collection_to_device(batch_value, device, use_non_blocking_transfer=False)[source]
Recursively move a nested batch structure to the requested device.
- Parameters:
batch_value (Any) – Tensor, list, tuple, dictionary, or scalar batch value.
device (device) – Target device for tensor transfer.
use_non_blocking_transfer (bool) – Whether CUDA transfers should request non-blocking behavior when possible.
- Returns:
Batch structure mirrored on the target device.
- Return type:
Any
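The recursion mirrors whatever container types it encounters while transferring only tensor leaves. The sketch below illustrates the pattern, not the actual implementation; it duck-types tensors via `hasattr(value, "to")` so it runs without torch, and `FakeTensor` is a stand-in used only for demonstration:

```python
from typing import Any


def move_to_device(batch_value: Any, device: Any, non_blocking: bool = False) -> Any:
    """Recursively mirror a nested batch structure onto `device`.

    Tensor-like leaves (anything with a `.to` method) are transferred;
    dicts, lists, and tuples are rebuilt with the same container type;
    scalars and strings pass through unchanged.
    """
    if hasattr(batch_value, "to"):
        return batch_value.to(device, non_blocking=non_blocking)
    if isinstance(batch_value, dict):
        return {key: move_to_device(val, device, non_blocking) for key, val in batch_value.items()}
    if isinstance(batch_value, (list, tuple)):
        return type(batch_value)(move_to_device(item, device, non_blocking) for item in batch_value)
    return batch_value


class FakeTensor:
    """Minimal tensor stand-in: records which device it was moved to."""

    def __init__(self, device: str = "cpu") -> None:
        self.device = device

    def to(self, device: str, non_blocking: bool = False) -> "FakeTensor":
        return FakeTensor(device)


batch = {"inputs": FakeTensor(), "meta": [FakeTensor(), 42]}
moved = move_to_device(batch, "cuda:0")
```

With real torch tensors, `non_blocking=True` only has an effect for CPU-to-CUDA copies out of pinned host memory, which is why the datamodule exposes `pin_memory` alongside `use_non_blocking_transfer`.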
- scripts.training.transmission_error_datamodule.extract_point_tensor_from_curve_sample(curve_sample_dictionary, point_stride=1, maximum_points_per_curve=None)[source]
Convert one curve sample into a point-sampled tensor dictionary.
- Parameters:
curve_sample_dictionary (dict[str, Any]) – Dataset sample containing full-curve tensors.
point_stride (int) – Step used to subsample curve points.
maximum_points_per_curve (int | None) – Optional hard cap on sampled points per curve.
- Returns:
Point-level tensors for model input, target, and angular position.
- Return type:
dict[str, torch.Tensor]
- scripts.training.transmission_error_datamodule.collate_transmission_error_points(batch_dictionary_list, point_stride=1, maximum_points_per_curve=None, shuffle_points=True)[source]
Collate curve samples into one point-level training batch.
- Parameters:
batch_dictionary_list (list[dict[str, Any]]) – Curve-level samples produced by the dataset.
point_stride (int) – Step used to subsample points from each curve.
maximum_points_per_curve (int | None) – Optional cap on sampled points per curve.
shuffle_points (bool) – Whether to shuffle the concatenated point batch.
- Returns:
Batch dictionary containing concatenated tensors and lightweight curve-level metadata.
- Return type:
dict[str, Any]
- class scripts.training.transmission_error_datamodule.TransmissionErrorDataModule(dataset_config_path, curve_batch_size=2, point_stride=20, maximum_points_per_curve=None, num_workers=0, pin_memory=False, use_non_blocking_transfer=False)[source]
LightningDataModule for TE curve splits, sampling, and normalization.
- Parameters:
dataset_config_path (str | Path)
curve_batch_size (int)
point_stride (int)
maximum_points_per_curve (int | None)
num_workers (int)
pin_memory (bool)
use_non_blocking_transfer (bool)
- __init__(dataset_config_path, curve_batch_size=2, point_stride=20, maximum_points_per_curve=None, num_workers=0, pin_memory=False, use_non_blocking_transfer=False)[source]
Initialize the reusable TE training datamodule.
- Parameters:
dataset_config_path (str | Path) – Dataset YAML configuration path.
curve_batch_size (int) – Number of curves loaded per dataloader batch.
point_stride (int) – Subsampling stride applied inside each curve.
maximum_points_per_curve (int | None) – Optional cap on sampled points per curve.
num_workers (int) – PyTorch dataloader worker count.
pin_memory (bool) – Whether dataloaders should pin host memory.
use_non_blocking_transfer (bool) – Whether device transfer should request non-blocking CUDA copies when possible.
- Return type:
None
- setup(stage=None)[source]
Create dataset splits and compute normalization statistics.
- Parameters:
stage (str | None) – Lightning stage hint. The current implementation uses the same prepared splits for fit, validation, and test access.
- Return type:
None
- compute_normalization_statistics(curve_dataset)[source]
Compute point-level normalization statistics from the train split.
- Parameters:
curve_dataset (TransmissionErrorCurveDataset) – Dataset used to accumulate sampled point statistics.
- Returns:
Mean and standard deviation tensors for model inputs and targets.
- Return type:
NormalizationStatistics
- get_input_feature_dim()[source]
Return the resolved input feature dimension after setup.
- Return type:
int
- get_target_feature_dim()[source]
Return the resolved target feature dimension after setup.
- Return type:
int
- get_normalization_statistics()[source]
Return the cached normalization statistics after setup.
- Return type:
NormalizationStatistics
- get_dataset_split_summary()[source]
Return the current split sizes in number of curves.
- Return type:
DatasetSplitSummary
- train_dataloader()[source]
Build the training dataloader with point-level collation.
- Return type:
DataLoader
- val_dataloader()[source]
Build the validation dataloader with deterministic point ordering.
- Return type:
DataLoader
- test_dataloader()[source]
Build the test dataloader with deterministic point ordering.
- Return type:
DataLoader
- transfer_batch_to_device(batch, device, dataloader_idx)[source]
Move a collated batch to the target accelerator device.
- Parameters:
batch (Any) – Batch emitted by one of the TE dataloaders.
device (device) – Target Lightning device.
dataloader_idx (int) – Dataloader index supplied by Lightning.
- Returns:
Batch moved to the requested device.
- Return type:
Any