"""Latent-state hysteresis-aware network for Wave 4.4 TE probes."""
from __future__ import annotations
# Import PyTorch Utilities
import torch
import torch.nn as nn
import torch.nn.functional as torch_functional
# Import Project Model Utilities
from scripts.models.feedforward_network import FeedForwardNetwork
from scripts.models.feedforward_network import get_activation_module
from scripts.models.temporal_sequence_network import resolve_sequence_readout_tensor
class CausalTemporalStateEncoder(nn.Module):
    """Compact causal temporal convolution state encoder.

    A stack of ``nn.Conv1d`` layers, each followed by an activation and an
    optional dropout layer, is applied with left-only zero padding so the
    output at a timestep never depends on later timesteps. The features at
    the final timestep are projected to ``hidden_size`` and returned as the
    latent state.
    """

    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        channel_size: list[int] | None = None,
        kernel_size: int = 5,
        activation_name: str = "GELU",
        dropout_probability: float = 0.10,
    ) -> None:
        """Initialize the causal temporal state encoder.

        Args:
            input_size: Number of features per timestep.
            hidden_size: Size of the returned latent state.
            channel_size: Output channels of each convolution layer. ``None``
                or an empty list falls back to two layers of ``hidden_size``
                channels.
            kernel_size: Temporal kernel width shared by every convolution.
            activation_name: Name forwarded to ``get_activation_module``.
            dropout_probability: Dropout probability applied after each
                activation; no dropout layer is created when it is ``0``.

        Raises:
            AssertionError: If any architecture parameter is out of range.
        """
        super().__init__()

        # Validate Architecture Parameters
        assert input_size > 0, f"Input Size must be positive | {input_size}"
        assert hidden_size > 0, f"Hidden Size must be positive | {hidden_size}"
        assert kernel_size > 0, f"Kernel Size must be positive | {kernel_size}"
        # Check the upper bound here as well, so an invalid value is reported in
        # the same style as the other parameters instead of by nn.Dropout later.
        assert 0.0 <= dropout_probability <= 1.0, (
            f"Dropout Probability must be within [0, 1] | {dropout_probability}"
        )

        # Save Architecture Parameters
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.channel_size = list(channel_size or [hidden_size, hidden_size])
        self.kernel_size = kernel_size
        self.activation_name = activation_name
        self.dropout_probability = dropout_probability

        # Build Causal Convolution Stack
        convolution_layer_list: list[nn.Module] = []
        previous_channel_size = input_size
        for current_channel_size in self.channel_size:
            assert current_channel_size > 0, f"Channel Size must be positive | {current_channel_size}"
            # No built-in padding: forward() pads on the left only.
            convolution_layer_list.append(nn.Conv1d(previous_channel_size, current_channel_size, kernel_size))
            convolution_layer_list.append(get_activation_module(self.activation_name))
            if self.dropout_probability > 0.0:
                convolution_layer_list.append(nn.Dropout(self.dropout_probability))
            previous_channel_size = current_channel_size
        self.temporal_network = nn.ModuleList(convolution_layer_list)
        self.output_projection = nn.Linear(previous_channel_size, hidden_size)

    def forward(self, input_tensor: torch.Tensor) -> torch.Tensor:
        """Encode a batch-first causal sequence into one latent state.

        Args:
            input_tensor: Rank-3 tensor laid out as (batch, time, feature),
                with ``feature == input_size`` and at least one timestep.

        Returns:
            Tensor of shape (batch, hidden_size) computed from the final
            timestep of the convolution stack.

        Raises:
            AssertionError: If the input rank, feature size, or sequence
                length is invalid.
        """
        # Validate Input Tensor
        assert input_tensor.ndim == 3, f"Causal TCN input must be rank-3 | {tuple(input_tensor.shape)}"
        assert input_tensor.shape[-1] == self.input_size, (
            f"Input feature mismatch | {input_tensor.shape[-1]} vs {self.input_size}"
        )
        # An empty time axis would otherwise be padded to kernel_size - 1 steps
        # and fail inside Conv1d with an unrelated kernel-size error.
        assert input_tensor.shape[1] > 0, (
            f"Causal TCN input must contain at least one timestep | {tuple(input_tensor.shape)}"
        )

        # Run Conv1d Layers With Left Padding Only
        left_padding = (self.kernel_size - 1, 0)
        channel_first_tensor = input_tensor.transpose(1, 2)
        for layer in self.temporal_network:
            if isinstance(layer, nn.Conv1d):
                # Left padding of kernel_size - 1 keeps the sequence length and
                # prevents any output step from seeing later inputs.
                channel_first_tensor = torch_functional.pad(channel_first_tensor, left_padding)
            channel_first_tensor = layer(channel_first_tensor)

        # Read Out The Final Timestep (channel-first layout: batch, channel, time)
        last_state_tensor = channel_first_tensor[:, :, -1]
        return self.output_projection(last_state_tensor)
class LatentStateHysteresisNetwork(nn.Module):
    """Causal latent-state model with base, offset, and residual TE heads.

    The prediction is the sum of three terms: a base branch fed only with the
    readout-step features, plus an offset head and a residual head that both
    receive the readout-step features concatenated with a latent state encoded
    from the whole input window (GRU or causal temporal convolution).
    """

    def __init__(
        self,
        input_size: int,
        output_size: int = 1,
        latent_encoder_type: str = "gru",
        latent_hidden_size: int = 96,
        latent_num_layers: int = 2,
        latent_dropout_probability: float = 0.10,
        latent_channel_size: list[int] | None = None,
        latent_kernel_size: int = 5,
        latent_activation_name: str = "GELU",
        readout_position: str = "last",
        base_hidden_size: list[int] | None = None,
        head_hidden_size: list[int] | None = None,
        head_activation_name: str = "GELU",
        head_dropout_probability: float = 0.05,
        use_layer_norm: bool = True,
        offset_scale: float = 1.0,
        residual_scale: float = 1.0,
    ) -> None:
        """Initialize the latent-state hysteresis-aware TE probe.

        Args:
            input_size: Features per timestep; must be at least 4.
            output_size: Width of every head output and of the prediction.
            latent_encoder_type: ``"gru"`` or ``"causal_tcn"``; surrounding
                whitespace and letter case are ignored.
            latent_hidden_size: Size of the latent state.
            latent_num_layers: Number of stacked GRU layers (GRU encoder only).
            latent_dropout_probability: Dropout inside the latent encoder.
            latent_channel_size: Channel layout for the causal TCN encoder.
            latent_kernel_size: Kernel width for the causal TCN encoder.
            latent_activation_name: Activation name for the causal TCN encoder.
            readout_position: Forwarded to ``resolve_sequence_readout_tensor``.
            base_hidden_size: Hidden layout of the base branch; defaults to
                ``[96, 64]``.
            head_hidden_size: Hidden layout shared by the offset and residual
                heads; defaults to ``[96, 64]``.
            head_activation_name: Activation name used by all three heads.
            head_dropout_probability: Dropout used by all three heads.
            use_layer_norm: Forwarded to every ``FeedForwardNetwork``.
            offset_scale: Positive multiplier applied to the offset head output.
            residual_scale: Positive multiplier applied to the residual head
                output.

        Raises:
            AssertionError: If a validated architecture parameter is invalid.
        """
        super().__init__()

        # Validate Architecture Parameters
        encoder_kind = latent_encoder_type.strip().lower()
        assert encoder_kind in ("gru", "causal_tcn"), f"Unsupported Latent Encoder Type | {latent_encoder_type}"
        assert input_size >= 4, f"Input Size must expose angle and TE operating features | {input_size}"
        assert output_size > 0, f"Output Size must be positive | {output_size}"
        assert latent_hidden_size > 0, f"Latent Hidden Size must be positive | {latent_hidden_size}"
        assert latent_num_layers > 0, f"Latent Num Layers must be positive | {latent_num_layers}"
        assert offset_scale > 0.0, f"Offset Scale must be positive | {offset_scale}"
        assert residual_scale > 0.0, f"Residual Scale must be positive | {residual_scale}"

        # Save Architecture Parameters
        self.input_size = input_size
        self.output_size = output_size
        self.latent_encoder_type = encoder_kind
        self.latent_hidden_size = latent_hidden_size
        self.latent_num_layers = latent_num_layers
        self.readout_position = readout_position
        self.offset_scale = float(offset_scale)
        self.residual_scale = float(residual_scale)

        # Build Causal Latent-State Encoder
        if encoder_kind == "gru":
            # Dropout is passed to the GRU only when more than one layer is stacked.
            gru_dropout_probability = 0.0 if latent_num_layers <= 1 else latent_dropout_probability
            self.latent_encoder = nn.GRU(
                input_size=input_size,
                hidden_size=latent_hidden_size,
                num_layers=latent_num_layers,
                batch_first=True,
                dropout=gru_dropout_probability,
                bidirectional=False,
            )
        else:
            self.latent_encoder = CausalTemporalStateEncoder(
                input_size=input_size,
                hidden_size=latent_hidden_size,
                channel_size=latent_channel_size,
                kernel_size=latent_kernel_size,
                activation_name=latent_activation_name,
                dropout_probability=latent_dropout_probability,
            )

        # Build Point Base And Latent Heads
        # Options shared by all three feed-forward heads.
        shared_head_options = {
            "output_size": output_size,
            "activation_name": head_activation_name,
            "dropout_probability": head_dropout_probability,
            "use_layer_norm": use_layer_norm,
        }
        self.base_branch = FeedForwardNetwork(
            input_size=input_size,
            hidden_size=list(base_hidden_size or [96, 64]),
            **shared_head_options,
        )
        # Latent heads see the readout features concatenated with the latent state.
        fused_feature_size = input_size + latent_hidden_size
        self.offset_head = FeedForwardNetwork(
            input_size=fused_feature_size,
            hidden_size=list(head_hidden_size or [96, 64]),
            **shared_head_options,
        )
        self.residual_head = FeedForwardNetwork(
            input_size=fused_feature_size,
            hidden_size=list(head_hidden_size or [96, 64]),
            **shared_head_options,
        )

    def resolve_readout_feature_tensor(self, sequence_tensor: torch.Tensor) -> torch.Tensor:
        """Extract the current operating-state feature tensor.

        Checks that ``sequence_tensor`` is rank-3 with ``input_size`` features
        and delegates to ``resolve_sequence_readout_tensor`` using the
        configured ``readout_position``.
        """
        # Validate Sequence Tensor
        sequence_shape = tuple(sequence_tensor.shape)
        assert sequence_tensor.ndim == 3, f"Sequence Tensor must be rank-3 | {sequence_shape}"
        feature_count = sequence_tensor.shape[-1]
        assert feature_count == self.input_size, (
            f"Input feature mismatch | {feature_count} vs {self.input_size}"
        )
        return resolve_sequence_readout_tensor(sequence_tensor, self.readout_position)

    def encode_latent_state(self, normalized_input_tensor: torch.Tensor) -> torch.Tensor:
        """Encode the causal operating-history window into one latent state.

        For the GRU encoder the final hidden state of the last layer is
        returned; any other encoder's output is returned unchanged.
        """
        encoder_output = self.latent_encoder(normalized_input_tensor)
        if self.latent_encoder_type != "gru":
            return encoder_output
        # nn.GRU returns (per-step outputs, final hidden states per layer).
        _per_step_output_tensor, final_hidden_tensor = encoder_output
        return final_hidden_tensor[-1]

    def compute_auxiliary_output_dictionary(
        self,
        input_tensor: torch.Tensor,
        normalized_input_tensor: torch.Tensor,
    ) -> dict[str, torch.Tensor]:
        """Expose base, latent, offset, residual, and final prediction tensors.

        Args:
            input_tensor: Raw sequence; only its rank and shape are checked
                here, it does not enter any computation.
            normalized_input_tensor: Normalized sequence fed to the encoder
                and to the heads.

        Returns:
            Dictionary holding the base prediction, the latent state, the
            scaled offset and residual predictions, and their sum under
            ``"prediction_tensor"``.
        """
        # Validate Sequence Inputs
        raw_shape = tuple(input_tensor.shape)
        normalized_shape = tuple(normalized_input_tensor.shape)
        assert input_tensor.ndim == 3, f"Input Tensor must be rank-3 | {raw_shape}"
        assert normalized_input_tensor.ndim == 3, (
            f"Normalized Input Tensor must be rank-3 | {normalized_shape}"
        )
        assert input_tensor.shape == normalized_input_tensor.shape, (
            f"Raw and normalized sequence shapes must match | {raw_shape} vs "
            f"{normalized_shape}"
        )

        # Compute Causal Readout And Latent State
        # NOTE(review): the latent state always summarises the full window; if
        # readout_position can select a non-final step, the latent heads would
        # see later timesteps than the readout features. Confirm against
        # resolve_sequence_readout_tensor.
        readout_feature_tensor = self.resolve_readout_feature_tensor(normalized_input_tensor)
        latent_state_tensor = self.encode_latent_state(normalized_input_tensor)
        fused_feature_tensor = torch.cat([readout_feature_tensor, latent_state_tensor], dim=-1)

        # Predict Base, Offset, Residual, And Final TE
        output_dictionary: dict[str, torch.Tensor] = {}
        output_dictionary["base_prediction_tensor"] = self.base_branch(readout_feature_tensor)
        output_dictionary["latent_state_tensor"] = latent_state_tensor
        output_dictionary["residual_offset_prediction_tensor"] = (
            self.offset_head(fused_feature_tensor) * self.offset_scale
        )
        output_dictionary["hysteresis_residual_prediction_tensor"] = (
            self.residual_head(fused_feature_tensor) * self.residual_scale
        )
        output_dictionary["prediction_tensor"] = (
            output_dictionary["base_prediction_tensor"]
            + output_dictionary["residual_offset_prediction_tensor"]
            + output_dictionary["hysteresis_residual_prediction_tensor"]
        )
        return output_dictionary

    def forward_with_input_context(self, input_tensor: torch.Tensor, normalized_input_tensor: torch.Tensor) -> torch.Tensor:
        """Predict normalized TE using raw context and normalized model input."""
        auxiliary_output_dictionary = self.compute_auxiliary_output_dictionary(
            input_tensor,
            normalized_input_tensor,
        )
        return auxiliary_output_dictionary["prediction_tensor"]

    def forward(self, normalized_input_tensor: torch.Tensor) -> torch.Tensor:
        """Run inference from normalized sequence inputs."""
        # Without raw context, the normalized tensor is used for both arguments.
        auxiliary_output_dictionary = self.compute_auxiliary_output_dictionary(
            normalized_input_tensor,
            normalized_input_tensor,
        )
        return auxiliary_output_dictionary["prediction_tensor"]