# Source code for scripts.models.wave3_grouped_harmonic_heads_network

"""Embryonic Wave 5.1 grouped harmonic-heads TE model."""

from __future__ import annotations

# Import Typing Utilities
from collections.abc import Sequence

# Import PyTorch Utilities
import torch
import torch.nn as nn

# Import Project Models
from scripts.models.feedforward_network import FeedForwardNetwork
from scripts.models.harmonic_regression import HarmonicRegression
from scripts.models.temporal_sequence_network import resolve_sequence_readout_tensor


[docs] class Wave3GroupedHarmonicHeadsNetwork(nn.Module): """Grouped harmonic heads plus residual correction skeleton."""
[docs] def __init__( self, input_size: int, output_size: int = 1, harmonic_order: int = 240, coefficient_mode: str = "linear_conditioned", low_order_harmonic_index_list: Sequence[int] | None = None, stable_middle_harmonic_index_list: Sequence[int] | None = None, high_order_harmonic_index_list: Sequence[int] | None = None, residual_hidden_size: list[int] | None = None, residual_activation_name: str = "GELU", residual_dropout_probability: float = 0.05, residual_use_layer_norm: bool = True, low_order_scale: float = 1.0, stable_middle_scale: float = 1.0, high_order_scale: float = 1.0, residual_scale: float = 1.0, readout_position: str = "center", freeze_harmonic_heads: bool = False, ) -> None: """Initialize the embryonic grouped-head Wave 5.1 model. Args: input_size: Input feature count, with angular position in the first feature column. output_size: Regression target count. The skeleton supports the repository's scalar TE target. harmonic_order: Contiguous fallback harmonic order. coefficient_mode: Harmonic branch coefficient mode. low_order_harmonic_index_list: Low-order / offset harmonic group. stable_middle_harmonic_index_list: Stable middle harmonic group. high_order_harmonic_index_list: High-order fragile harmonic group. residual_hidden_size: Residual branch hidden widths. residual_activation_name: Residual branch activation name. residual_dropout_probability: Residual branch dropout. residual_use_layer_norm: Whether the residual branch uses layer normalization. low_order_scale: Scale applied to the low-order harmonic head. stable_middle_scale: Scale applied to the stable middle head. high_order_scale: Scale applied to the high-order harmonic head. residual_scale: Scale applied to the residual correction head. readout_position: Readout position for sequence batches. freeze_harmonic_heads: Whether all harmonic heads are frozen. 
""" super().__init__() # Resolve Defaults low_order_harmonic_index_list = list(low_order_harmonic_index_list or [0, 1]) stable_middle_harmonic_index_list = list(stable_middle_harmonic_index_list or [3, 39, 40, 78, 81]) high_order_harmonic_index_list = list(high_order_harmonic_index_list or [156, 162, 240]) residual_hidden_size = residual_hidden_size or [96, 64] # Validate Parameters assert input_size >= 4, f"Input Size must expose angle and TE operating features | {input_size}" assert output_size == 1, f"Wave 5.1 grouped heads support scalar TE output only | {output_size}" assert harmonic_order > 0, f"Harmonic Order must be positive | {harmonic_order}" assert low_order_scale >= 0.0, f"Low Order Scale must be non-negative | {low_order_scale}" assert stable_middle_scale >= 0.0, f"Stable Middle Scale must be non-negative | {stable_middle_scale}" assert high_order_scale >= 0.0, f"High Order Scale must be non-negative | {high_order_scale}" assert residual_scale >= 0.0, f"Residual Scale must be non-negative | {residual_scale}" # Save Parameters self.input_size = int(input_size) self.output_size = int(output_size) self.harmonic_order = int(harmonic_order) self.coefficient_mode = str(coefficient_mode) self.low_order_scale = float(low_order_scale) self.stable_middle_scale = float(stable_middle_scale) self.high_order_scale = float(high_order_scale) self.residual_scale = float(residual_scale) self.readout_position = str(readout_position) self.freeze_harmonic_heads = bool(freeze_harmonic_heads) # Register Harmonic Group Buffers For Stable Diagnostics self.register_buffer( "low_order_harmonic_index_tensor", torch.as_tensor(low_order_harmonic_index_list, dtype=torch.long), persistent=True, ) self.register_buffer( "stable_middle_harmonic_index_tensor", torch.as_tensor(stable_middle_harmonic_index_list, dtype=torch.long), persistent=True, ) self.register_buffer( "high_order_harmonic_index_tensor", torch.as_tensor(high_order_harmonic_index_list, dtype=torch.long), persistent=True, 
) # Initialize Grouped Harmonic Heads self.low_order_head = HarmonicRegression( input_size=input_size, output_size=output_size, harmonic_order=harmonic_order, coefficient_mode=coefficient_mode, harmonic_index_list=low_order_harmonic_index_list, ) self.stable_middle_head = HarmonicRegression( input_size=input_size, output_size=output_size, harmonic_order=harmonic_order, coefficient_mode=coefficient_mode, harmonic_index_list=stable_middle_harmonic_index_list, ) self.high_order_head = HarmonicRegression( input_size=input_size, output_size=output_size, harmonic_order=harmonic_order, coefficient_mode=coefficient_mode, harmonic_index_list=high_order_harmonic_index_list, ) # Optionally Freeze Harmonic Heads if self.freeze_harmonic_heads: for harmonic_head in [self.low_order_head, self.stable_middle_head, self.high_order_head]: for harmonic_parameter in harmonic_head.parameters(): harmonic_parameter.requires_grad = False # Initialize Residual Shape Head self.residual_branch = FeedForwardNetwork( input_size=input_size, hidden_size=residual_hidden_size, output_size=output_size, activation_name=residual_activation_name, dropout_probability=residual_dropout_probability, use_layer_norm=residual_use_layer_norm, )
[docs] def resolve_readout_tensor(self, input_tensor: torch.Tensor) -> torch.Tensor: """Return point-level features from point or sequence input.""" if input_tensor.ndim == 2: assert input_tensor.shape[-1] == self.input_size, ( f"Input feature mismatch | {input_tensor.shape[-1]} vs {self.input_size}" ) return input_tensor assert input_tensor.ndim == 3, f"Input Tensor must be rank-2 or rank-3 | {tuple(input_tensor.shape)}" assert input_tensor.shape[-1] == self.input_size, ( f"Input feature mismatch | {input_tensor.shape[-1]} vs {self.input_size}" ) return resolve_sequence_readout_tensor(input_tensor, self.readout_position)
[docs] def compute_auxiliary_output_dictionary( self, input_tensor: torch.Tensor, normalized_input_tensor: torch.Tensor, ) -> dict[str, torch.Tensor]: """Expose grouped harmonic, residual, and combined prediction tensors.""" # Resolve Point-Level Views readout_input_tensor = self.resolve_readout_tensor(input_tensor) readout_normalized_input_tensor = self.resolve_readout_tensor(normalized_input_tensor) # Compute Grouped Harmonic Heads low_order_prediction_tensor = self.low_order_head.forward_with_input_context( readout_input_tensor, readout_normalized_input_tensor, ) * self.low_order_scale stable_middle_prediction_tensor = self.stable_middle_head.forward_with_input_context( readout_input_tensor, readout_normalized_input_tensor, ) * self.stable_middle_scale high_order_prediction_tensor = self.high_order_head.forward_with_input_context( readout_input_tensor, readout_normalized_input_tensor, ) * self.high_order_scale # Compute Residual Shape Correction residual_prediction_tensor = self.residual_branch(readout_normalized_input_tensor) * self.residual_scale grouped_harmonic_prediction_tensor = ( low_order_prediction_tensor + stable_middle_prediction_tensor + high_order_prediction_tensor ) prediction_tensor = grouped_harmonic_prediction_tensor + residual_prediction_tensor return { "low_order_prediction_tensor": low_order_prediction_tensor, "stable_middle_prediction_tensor": stable_middle_prediction_tensor, "high_order_prediction_tensor": high_order_prediction_tensor, "grouped_harmonic_prediction_tensor": grouped_harmonic_prediction_tensor, "residual_prediction_tensor": residual_prediction_tensor, "wave3_grouped_residual_prediction_tensor": residual_prediction_tensor, "prediction_tensor": prediction_tensor, }
[docs] def forward_with_input_context(self, input_tensor: torch.Tensor, normalized_input_tensor: torch.Tensor) -> torch.Tensor: """Predict normalized TE using raw angle context and normalized features.""" return self.compute_auxiliary_output_dictionary(input_tensor, normalized_input_tensor)["prediction_tensor"]
[docs] def forward(self, normalized_input_tensor: torch.Tensor) -> torch.Tensor: """Run inference from normalized inputs when raw context is unavailable.""" return self.compute_auxiliary_output_dictionary( normalized_input_tensor, normalized_input_tensor, )["prediction_tensor"]