"""Residual harmonic temporal sequence networks for Wave 2.3 TE candidates."""
from __future__ import annotations
# Import PyTorch Utilities
import torch
import torch.nn as nn
# Import Project Models
from scripts.models.harmonic_regression import HarmonicRegression
from scripts.models.temporal_sequence_network import RecurrentSequenceNetwork
from scripts.models.temporal_sequence_network import resolve_sequence_readout_tensor
class ResidualHarmonicTemporalSequenceNetwork(nn.Module):
    """Sequence TE model with harmonic base prediction plus temporal residual.

    The final prediction is the element-wise sum of two branches:

    * a structured ``HarmonicRegression`` branch evaluated on the single
      feature vector extracted at ``readout_position`` of each sequence, and
    * a ``RecurrentSequenceNetwork`` (GRU or LSTM) residual branch that
      consumes the full normalized sequence.

    NOTE(review): this module defines no ``forward``; prediction goes through
    ``forward_with_input_context`` because both the raw and the normalized
    sequence are required. Calling the module directly falls back to
    ``nn.Module``'s unimplemented ``forward`` and raises ``NotImplementedError``.
    """

    def __init__(
        self,
        temporal_model_type: str,
        input_size: int,
        output_size: int = 1,
        harmonic_order: int = 12,
        coefficient_mode: str = "static",
        harmonic_index_list: list[int] | None = None,
        hidden_size: int = 128,
        num_layers: int = 2,
        dropout_probability: float = 0.10,
        bidirectional: bool = False,
        readout_position: str = "center",
        freeze_structured_branch: bool = False,
    ) -> None:
        """Initialize one residual harmonic recurrent sequence model.

        Args:
            temporal_model_type: Temporal residual selector. Supported values
                are `gru_sequence` and `lstm_sequence` (matched after
                stripping surrounding whitespace and lower-casing).
            input_size: Raw sequence feature count, including angular position
                as the first feature. Must be at least 4.
            output_size: Regression target count. Must be positive.
            harmonic_order: Contiguous harmonic order used when no explicit
                harmonic index list is provided. Must be positive.
            coefficient_mode: Harmonic coefficient parameterization mode passed
                to the structured branch.
            harmonic_index_list: Optional explicit non-negative harmonic list.
            hidden_size: Recurrent hidden size for the residual branch.
            num_layers: Recurrent layer count.
            dropout_probability: Dropout probability used by the recurrent
                residual branch.
            bidirectional: Whether the recurrent residual branch is
                bidirectional.
            readout_position: Sequence readout position used by both the
                residual branch and center-vector extraction.
            freeze_structured_branch: Whether to freeze the structured
                harmonic branch parameters during optimization.

        Raises:
            AssertionError: If the temporal model type is unsupported, or if
                ``input_size < 4``, ``output_size <= 0`` or
                ``harmonic_order <= 0``.
        """
        super().__init__()
        # Validate Architecture Parameters (type selector is case/whitespace-insensitive)
        normalized_temporal_model_type = temporal_model_type.strip().lower()
        supported_temporal_type_list = ["gru_sequence", "lstm_sequence"]
        assert normalized_temporal_model_type in supported_temporal_type_list, (
            f"Unsupported Residual Harmonic Temporal Model Type | {temporal_model_type}"
        )
        assert input_size >= 4, f"Input Size must expose angle and TE operating-condition features | {input_size}"
        assert output_size > 0, f"Output Size must be positive | {output_size}"
        assert harmonic_order > 0, f"Harmonic Order must be positive | {harmonic_order}"
        # Save Architecture Parameters
        self.temporal_model_type = normalized_temporal_model_type
        self.input_size = input_size
        self.output_size = output_size
        self.harmonic_order = harmonic_order
        self.coefficient_mode = coefficient_mode
        self.readout_position = readout_position
        self.freeze_structured_branch = freeze_structured_branch
        # Initialize Structured Harmonic Branch
        self.structured_branch = HarmonicRegression(
            input_size=input_size,
            output_size=output_size,
            harmonic_order=harmonic_order,
            coefficient_mode=coefficient_mode,
            harmonic_index_list=harmonic_index_list,
        )
        # Mirror the harmonic lists as resolved by the structured branch (not the raw argument)
        self.harmonic_index_list = self.structured_branch.harmonic_index_list
        self.positive_harmonic_index_list = self.structured_branch.positive_harmonic_index_list
        # Optionally Freeze Structured Parameters
        if self.freeze_structured_branch:
            # Disable Structured Gradients (only requires_grad is toggled; train/eval mode is untouched)
            for structured_parameter in self.structured_branch.parameters():
                structured_parameter.requires_grad = False
        # Initialize Temporal Residual Branch (type was validated above, so non-GRU means LSTM)
        recurrent_type = "gru" if self.temporal_model_type == "gru_sequence" else "lstm"
        self.residual_branch = RecurrentSequenceNetwork(
            recurrent_type=recurrent_type,
            input_size=input_size,
            hidden_size=hidden_size,
            output_size=output_size,
            num_layers=num_layers,
            dropout_probability=dropout_probability,
            bidirectional=bidirectional,
            readout_position=readout_position,
        )

    def resolve_readout_feature_tensor(self, sequence_tensor: torch.Tensor) -> torch.Tensor:
        """Extract the rank-2 feature tensor used by the structured branch.

        Args:
            sequence_tensor: Rank-3 sequence tensor whose last dimension is the
                feature axis (presumably ``(batch, sequence, feature)`` -- TODO
                confirm against ``resolve_sequence_readout_tensor``).

        Returns:
            torch.Tensor: The feature vectors selected at ``self.readout_position``.

        Raises:
            AssertionError: If the tensor is not rank-3 or its feature count
                differs from ``self.input_size``.
        """
        # Validate Sequence Tensor
        assert sequence_tensor.ndim == 3, f"Sequence Tensor must be rank-3 | {tuple(sequence_tensor.shape)}"
        assert sequence_tensor.shape[-1] == self.input_size, (
            f"Input feature mismatch | {sequence_tensor.shape[-1]} vs {self.input_size}"
        )
        return resolve_sequence_readout_tensor(sequence_tensor, self.readout_position)

    def compute_auxiliary_output_dictionary(
        self,
        input_tensor: torch.Tensor,
        normalized_input_tensor: torch.Tensor,
    ) -> dict[str, torch.Tensor]:
        """Expose branch-level outputs for diagnostics and metric logging.

        Args:
            input_tensor: Raw rank-3 sequence tensor whose first feature is
                physical angular position in degrees.
            normalized_input_tensor: Normalized rank-3 sequence tensor used by
                the temporal residual branch.

        Returns:
            dict[str, torch.Tensor]: Mapping with keys
            ``structured_prediction_tensor`` (harmonic branch output),
            ``residual_prediction_tensor`` (recurrent branch output), and
            ``prediction_tensor`` (their element-wise sum).

        Raises:
            AssertionError: If either input is not rank-3, the two inputs differ
                in shape, the feature count does not match ``input_size``, or
                the two branch outputs differ in shape.
        """
        # Validate Sequence Inputs
        assert input_tensor.ndim == 3, f"Input Tensor must be rank-3 | {tuple(input_tensor.shape)}"
        assert normalized_input_tensor.ndim == 3, (
            f"Normalized Input Tensor must be rank-3 | {tuple(normalized_input_tensor.shape)}"
        )
        assert input_tensor.shape == normalized_input_tensor.shape, (
            f"Raw and normalized sequence shapes must match | {tuple(input_tensor.shape)} vs "
            f"{tuple(normalized_input_tensor.shape)}"
        )
        # Extract Structured Branch Inputs At The Configured Readout Position
        readout_input_tensor = self.resolve_readout_feature_tensor(input_tensor)
        readout_normalized_input_tensor = self.resolve_readout_feature_tensor(normalized_input_tensor)
        # Forward Pass Through Structured Harmonic Branch (needs raw angle plus normalized context)
        structured_prediction_tensor = self.structured_branch.forward_with_input_context(
            readout_input_tensor,
            readout_normalized_input_tensor,
        )
        # Forward Pass Through Temporal Residual Branch (consumes the full normalized sequence)
        residual_prediction_tensor = self.residual_branch(normalized_input_tensor)
        # Guard Against Silent Broadcasting: mismatched branch shapes would otherwise be
        # broadcast by `+` into a wrongly shaped prediction instead of failing loudly.
        assert structured_prediction_tensor.shape == residual_prediction_tensor.shape, (
            f"Structured and residual prediction shapes must match | "
            f"{tuple(structured_prediction_tensor.shape)} vs {tuple(residual_prediction_tensor.shape)}"
        )
        # Return Branch-Level Diagnostics
        return {
            "structured_prediction_tensor": structured_prediction_tensor,
            "residual_prediction_tensor": residual_prediction_tensor,
            "prediction_tensor": structured_prediction_tensor + residual_prediction_tensor,
        }

    def forward_with_input_context(self, input_tensor: torch.Tensor, normalized_input_tensor: torch.Tensor) -> torch.Tensor:
        """Predict TE from a harmonic base plus recurrent residual correction.

        Args:
            input_tensor: Raw rank-3 sequence tensor (angle in degrees first).
            normalized_input_tensor: Normalized rank-3 sequence tensor of the
                same shape.

        Returns:
            torch.Tensor: Sum of the structured and residual branch outputs.
        """
        return self.compute_auxiliary_output_dictionary(input_tensor, normalized_input_tensor)["prediction_tensor"]