"""Temporal sequence networks for Wave 2.1 TE regression candidates."""
from __future__ import annotations
# Import PyTorch Utilities
import torch
import torch.nn as nn
# Import Project Model Utilities
from scripts.models.feedforward_network import get_activation_module
[docs]
def resolve_sequence_readout_tensor(sequence_tensor: torch.Tensor, readout_position: str) -> torch.Tensor:
    """Pick a single timestep out of a batched, rank-3 sequence tensor.

    Args:
        sequence_tensor: Rank-3 tensor indexed along its second axis
            (assumed to be the time axis, as in the callers in this module).
        readout_position: ``"center"`` selects index ``length // 2`` (the upper
            middle element for even lengths); ``"last"`` selects the final
            timestep. Matching ignores case and surrounding whitespace.

    Returns:
        Rank-2 tensor holding the selected timestep for every batch entry.

    Raises:
        AssertionError: If the tensor is not rank-3 or the readout position is
            not one of the supported values.
    """
    # Validate Sequence Tensor
    assert sequence_tensor.ndim == 3, f"Sequence Tensor must be rank-3 | {tuple(sequence_tensor.shape)}"
    # Map Readout Position To A Timestep Index
    position_key = readout_position.strip().lower()
    assert position_key in ("center", "last"), f"Unsupported Readout Position | {readout_position}"
    timestep_index = sequence_tensor.shape[1] // 2 if position_key == "center" else -1
    return sequence_tensor[:, timestep_index, :]
[docs]
class TemporalConvolutionNetwork(nn.Module):
    """Non-causal temporal convolutional regressor for TE sequence windows.

    Each ``Conv1d`` layer uses symmetric ``kernel_size // 2`` padding, so every
    output timestep sees neighbours on both sides (past and future) and the
    sequence length is preserved through the stack. One timestep is then read
    out and projected to ``output_size`` by a linear layer.
    """

    def __init__(
        self,
        input_size: int,
        channel_size: list[int],
        output_size: int = 1,
        kernel_size: int = 5,
        activation_name: str = "GELU",
        dropout_probability: float = 0.10,
        readout_position: str = "center",
    ) -> None:
        """Initialize the temporal convolutional regression backbone.

        Args:
            input_size: Number of input features per timestep (last input axis).
            channel_size: Output channel count of each convolution layer, in
                order; must be non-empty with positive entries.
            output_size: Number of regression outputs.
            kernel_size: Positive odd temporal kernel width (odd so the
                symmetric padding keeps the sequence length unchanged).
            activation_name: Name passed to ``get_activation_module`` for the
                activation after every convolution.
            dropout_probability: Dropout applied after each activation; ``0.0``
                omits the dropout layers entirely.
            readout_position: ``"center"`` or ``"last"`` (case- and
                whitespace-insensitive); selects the timestep fed to the head.

        Raises:
            AssertionError: If any architecture parameter is invalid.
        """
        super().__init__()
        channel_size_list = list(channel_size)
        # Validate Architecture Parameters
        assert input_size > 0, f"Input Size must be positive | {input_size}"
        assert output_size > 0, f"Output Size must be positive | {output_size}"
        assert len(channel_size_list) > 0, "Channel Size list must contain at least one layer"
        for current_channel_size in channel_size_list:
            assert current_channel_size > 0, f"Channel Size must be positive | {current_channel_size}"
        assert kernel_size > 0 and kernel_size % 2 == 1, f"Kernel Size must be a positive odd integer | {kernel_size}"
        assert dropout_probability >= 0.0, f"Dropout Probability must be non-negative | {dropout_probability}"
        # Fail fast on a bad readout instead of at the first forward pass
        assert readout_position.strip().lower() in ("center", "last"), (
            f"Unsupported Readout Position | {readout_position}"
        )
        # Save Architecture Parameters
        self.input_size = input_size
        self.channel_size = channel_size_list
        self.output_size = output_size
        self.kernel_size = kernel_size
        self.activation_name = activation_name
        self.dropout_probability = dropout_probability
        self.readout_position = readout_position
        # Build Temporal Convolution Stack
        convolution_layer_list: list[nn.Module] = []
        previous_channel_size = input_size
        # Symmetric padding with an odd kernel keeps the sequence length fixed
        padding_size = kernel_size // 2
        for current_channel_size in self.channel_size:
            convolution_layer_list.append(nn.Conv1d(previous_channel_size, current_channel_size, kernel_size, padding=padding_size))
            convolution_layer_list.append(get_activation_module(self.activation_name))
            if self.dropout_probability > 0.0:
                convolution_layer_list.append(nn.Dropout(self.dropout_probability))
            previous_channel_size = current_channel_size
        self.temporal_network = nn.Sequential(*convolution_layer_list)
        self.output_layer = nn.Linear(previous_channel_size, output_size)

    def forward(self, input_tensor: torch.Tensor) -> torch.Tensor:
        """Run the temporal convolutional network on rank-3 input windows.

        Args:
            input_tensor: Tensor of shape ``(batch, sequence, input_size)``.

        Returns:
            Tensor of shape ``(batch, output_size)`` computed from the
            configured readout timestep.

        Raises:
            AssertionError: If the input is not rank-3 or its feature axis does
                not match ``input_size``.
        """
        # Validate Input Tensor
        assert input_tensor.ndim == 3, f"Temporal Convolution input must be rank-3 | {tuple(input_tensor.shape)}"
        assert input_tensor.shape[-1] == self.input_size, (
            f"Input feature mismatch | {input_tensor.shape[-1]} vs {self.input_size}"
        )
        # Conv1d Expects Batch, Channel, Sequence
        channel_first_input_tensor = input_tensor.transpose(1, 2)
        channel_first_output_tensor = self.temporal_network(channel_first_input_tensor)
        sequence_output_tensor = channel_first_output_tensor.transpose(1, 2)
        readout_tensor = resolve_sequence_readout_tensor(sequence_output_tensor, self.readout_position)
        return self.output_layer(readout_tensor)
[docs]
class RecurrentSequenceNetwork(nn.Module):
    """GRU or LSTM sequence regressor with an explicit temporal readout.

    The recurrent stack runs over the whole window; one timestep of its output
    sequence is selected, passed through dropout, and projected to
    ``output_size`` by a linear layer.
    """

    def __init__(
        self,
        recurrent_type: str,
        input_size: int,
        hidden_size: int,
        output_size: int = 1,
        num_layers: int = 2,
        dropout_probability: float = 0.10,
        bidirectional: bool = False,
        readout_position: str = "center",
    ) -> None:
        """Initialize a recurrent sequence regression backbone.

        Args:
            recurrent_type: ``"gru"`` or ``"lstm"`` (case- and
                whitespace-insensitive).
            input_size: Number of input features per timestep.
            hidden_size: Hidden units per direction of each recurrent layer.
            output_size: Number of regression outputs.
            num_layers: Number of stacked recurrent layers.
            dropout_probability: Dropout between stacked recurrent layers (only
                when ``num_layers > 1``) and on the readout before the head.
            bidirectional: Whether the recurrent layers run in both directions;
                doubles the readout feature width.
            readout_position: ``"center"`` or ``"last"`` (case- and
                whitespace-insensitive); selects the timestep fed to the head.

        Raises:
            AssertionError: If any architecture parameter is invalid.
        """
        super().__init__()
        # Validate Architecture Parameters
        normalized_recurrent_type = recurrent_type.strip().lower()
        assert normalized_recurrent_type in ("gru", "lstm"), f"Unsupported Recurrent Type | {recurrent_type}"
        assert input_size > 0, f"Input Size must be positive | {input_size}"
        assert hidden_size > 0, f"Hidden Size must be positive | {hidden_size}"
        assert output_size > 0, f"Output Size must be positive | {output_size}"
        assert num_layers > 0, f"Num Layers must be positive | {num_layers}"
        assert dropout_probability >= 0.0, f"Dropout Probability must be non-negative | {dropout_probability}"
        # Fail fast on a bad readout instead of at the first forward pass
        assert readout_position.strip().lower() in ("center", "last"), (
            f"Unsupported Readout Position | {readout_position}"
        )
        # NOTE(review): with bidirectional=True and readout "last", the backward
        # half of the output at the final timestep has only processed that one
        # step; "center" (or a final-hidden-state readout) may be preferable.
        # Save Architecture Parameters
        self.recurrent_type = normalized_recurrent_type
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.num_layers = num_layers
        self.dropout_probability = dropout_probability
        self.bidirectional = bidirectional
        self.readout_position = readout_position
        # PyTorch applies recurrent dropout only between stacked layers.
        recurrent_dropout_probability = dropout_probability if num_layers > 1 else 0.0
        recurrent_class = nn.GRU if self.recurrent_type == "gru" else nn.LSTM
        self.recurrent_network = recurrent_class(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
            dropout=recurrent_dropout_probability,
            bidirectional=bidirectional,
        )
        # Bidirectional outputs concatenate both directions on the feature axis
        recurrent_output_size = hidden_size * (2 if bidirectional else 1)
        self.output_dropout = nn.Dropout(dropout_probability) if dropout_probability > 0.0 else nn.Identity()
        self.output_layer = nn.Linear(recurrent_output_size, output_size)

    def forward(self, input_tensor: torch.Tensor) -> torch.Tensor:
        """Run the recurrent sequence network on rank-3 input windows.

        Args:
            input_tensor: Tensor of shape ``(batch, sequence, input_size)``
                (the recurrent layer is built with ``batch_first=True``).

        Returns:
            Tensor of shape ``(batch, output_size)`` computed from the
            configured readout timestep.

        Raises:
            AssertionError: If the input is not rank-3 or its feature axis does
                not match ``input_size``.
        """
        # Validate Input Tensor
        assert input_tensor.ndim == 3, f"Recurrent input must be rank-3 | {tuple(input_tensor.shape)}"
        assert input_tensor.shape[-1] == self.input_size, (
            f"Input feature mismatch | {input_tensor.shape[-1]} vs {self.input_size}"
        )
        # Read Recurrent Output Sequence (final hidden/cell states are unused)
        recurrent_output_tensor, _ = self.recurrent_network(input_tensor)
        readout_tensor = resolve_sequence_readout_tensor(recurrent_output_tensor, self.readout_position)
        dropped_readout_tensor = self.output_dropout(readout_tensor)
        return self.output_layer(dropped_readout_tensor)