"""Factory helpers that map TE model-type strings to concrete modules."""
from __future__ import annotations
# Import Typing Utilities
from typing import Any
# Import PyTorch Utilities
import torch.nn as nn
# Import Project Models
from scripts.models.feedforward_network import FeedForwardNetwork
from scripts.models.harmonic_regression import HarmonicRegression
from scripts.models.harmonic_residual_offset_network import HarmonicResidualOffsetNetwork
from scripts.models.latent_state_hysteresis_network import LatentStateHysteresisNetwork
from scripts.models.periodic_feature_network import PeriodicFeatureNetwork
from scripts.models.periodic_temporal_sequence_network import PeriodicTemporalSequenceNetwork
from scripts.models.residual_harmonic_network import ResidualHarmonicNetwork
from scripts.models.residual_harmonic_temporal_sequence_network import ResidualHarmonicTemporalSequenceNetwork
from scripts.models.sequential_residual_offset_network import SequentialResidualOffsetNetwork
from scripts.models.temporal_sequence_network import RecurrentSequenceNetwork
from scripts.models.temporal_sequence_network import TemporalConvolutionNetwork
from scripts.models.wave3_grouped_harmonic_heads_network import Wave3GroupedHarmonicHeadsNetwork
from scripts.models.wave3_harmonic_prior_residual_network import Wave3HarmonicPriorResidualNetwork
def _mlp_trunk_kwargs(configuration: dict[str, Any]) -> dict[str, Any]:
    """Collect the hidden-layer options shared by `feedforward` and `periodic_mlp`.

    Every key read here is required; a missing key raises `KeyError`.
    """
    return {
        "hidden_size": list(configuration["hidden_size"]),
        "activation_name": str(configuration["activation_name"]),
        "dropout_probability": float(configuration["dropout_probability"]),
        "use_layer_norm": bool(configuration["use_layer_norm"]),
    }


def _harmonic_kwargs(
    configuration: dict[str, Any],
    default_coefficient_mode: str = "static",
) -> dict[str, Any]:
    """Collect the harmonic-basis options used by the harmonic model families.

    `harmonic_order` is required. `coefficient_mode` falls back to
    `default_coefficient_mode`, and `harmonic_index_list` falls back to `None`.
    """
    return {
        "harmonic_order": int(configuration["harmonic_order"]),
        "coefficient_mode": str(configuration.get("coefficient_mode", default_coefficient_mode)),
        "harmonic_index_list": configuration.get("harmonic_index_list"),
    }


def _periodic_feature_kwargs(configuration: dict[str, Any]) -> dict[str, Any]:
    """Collect the periodic-feature options shared by every `periodic_*` family.

    `harmonic_order` is required. `harmonic_index_list` defaults to `None` and
    `include_raw_angle_feature` defaults to `True`.
    """
    return {
        "harmonic_order": int(configuration["harmonic_order"]),
        "harmonic_index_list": configuration.get("harmonic_index_list"),
        "include_raw_angle_feature": bool(configuration.get("include_raw_angle_feature", True)),
    }


def _recurrent_kwargs(configuration: dict[str, Any]) -> dict[str, Any]:
    """Collect the recurrent-backbone options shared by every GRU/LSTM family.

    `hidden_size` is required. The other options default to 2 layers,
    dropout 0.10, unidirectional, and a `center` readout.
    """
    return {
        "hidden_size": int(configuration["hidden_size"]),
        "num_layers": int(configuration.get("num_layers", 2)),
        "dropout_probability": float(configuration.get("dropout_probability", 0.10)),
        "bidirectional": bool(configuration.get("bidirectional", False)),
        "readout_position": str(configuration.get("readout_position", "center")),
    }


def _convolution_kwargs(configuration: dict[str, Any]) -> dict[str, Any]:
    """Collect the temporal-convolution options shared by the convolution families.

    `channel_size` is required. The other options default to kernel size 5,
    `GELU`, dropout 0.10, and a `center` readout.
    """
    return {
        "channel_size": list(configuration["channel_size"]),
        "kernel_size": int(configuration.get("kernel_size", 5)),
        "activation_name": str(configuration.get("activation_name", "GELU")),
        "dropout_probability": float(configuration.get("dropout_probability", 0.10)),
        "readout_position": str(configuration.get("readout_position", "center")),
    }


def _offset_branch_kwargs(configuration: dict[str, Any]) -> dict[str, Any]:
    """Collect the offset-branch options shared by the residual-offset probes.

    Every key is optional. The defaults are 96 hidden units, 2 layers,
    dropout 0.10, unidirectional, a `center` readout, and scale 1.0.
    """
    return {
        "offset_hidden_size": int(configuration.get("offset_hidden_size", 96)),
        "offset_num_layers": int(configuration.get("offset_num_layers", 2)),
        "offset_dropout_probability": float(configuration.get("offset_dropout_probability", 0.10)),
        "offset_bidirectional": bool(configuration.get("offset_bidirectional", False)),
        "offset_readout_position": str(configuration.get("offset_readout_position", "center")),
        "offset_scale": float(configuration.get("offset_scale", 1.0)),
    }


def _wave3_shared_kwargs(configuration: dict[str, Any]) -> dict[str, Any]:
    """Collect the options shared by both Wave 3 skeleton networks.

    Every key is optional. Notably, `harmonic_order` defaults to 240 and
    `coefficient_mode` defaults to `linear_conditioned`.
    """
    return {
        "harmonic_order": int(configuration.get("harmonic_order", 240)),
        "coefficient_mode": str(configuration.get("coefficient_mode", "linear_conditioned")),
        "low_order_harmonic_index_list": configuration.get("low_order_harmonic_index_list"),
        "stable_middle_harmonic_index_list": configuration.get("stable_middle_harmonic_index_list"),
        "high_order_harmonic_index_list": configuration.get("high_order_harmonic_index_list"),
        "residual_hidden_size": list(configuration.get("residual_hidden_size", [96, 64])),
        "residual_activation_name": str(configuration.get("residual_activation_name", "GELU")),
        "residual_dropout_probability": float(configuration.get("residual_dropout_probability", 0.05)),
        "residual_use_layer_norm": bool(configuration.get("residual_use_layer_norm", True)),
        "residual_scale": float(configuration.get("residual_scale", 1.0)),
        "readout_position": str(configuration.get("readout_position", "center")),
    }


def _build_feedforward(configuration: dict[str, Any]) -> nn.Module:
    """Build the plain feedforward network (`output_size` is required)."""
    return FeedForwardNetwork(
        input_size=int(configuration["input_size"]),
        output_size=int(configuration["output_size"]),
        **_mlp_trunk_kwargs(configuration),
    )


def _build_harmonic_regression(configuration: dict[str, Any]) -> nn.Module:
    """Build the harmonic-regression baseline."""
    return HarmonicRegression(
        input_size=int(configuration["input_size"]),
        output_size=int(configuration.get("output_size", 1)),
        **_harmonic_kwargs(configuration),
    )


def _build_periodic_mlp(configuration: dict[str, Any]) -> nn.Module:
    """Build the periodic-feature feedforward network (`output_size` is required)."""
    return PeriodicFeatureNetwork(
        input_size=int(configuration["input_size"]),
        output_size=int(configuration["output_size"]),
        **_mlp_trunk_kwargs(configuration),
        **_periodic_feature_kwargs(configuration),
    )


def _build_residual_harmonic_mlp(configuration: dict[str, Any]) -> nn.Module:
    """Build the harmonic branch plus a feedforward residual branch."""
    return ResidualHarmonicNetwork(
        input_size=int(configuration["input_size"]),
        output_size=int(configuration.get("output_size", 1)),
        residual_hidden_size=list(configuration["residual_hidden_size"]),
        residual_activation_name=str(configuration.get("residual_activation_name", "GELU")),
        residual_dropout_probability=float(configuration.get("residual_dropout_probability", 0.10)),
        residual_use_layer_norm=bool(configuration.get("residual_use_layer_norm", True)),
        freeze_structured_branch=bool(configuration.get("freeze_structured_branch", False)),
        **_harmonic_kwargs(configuration),
    )


def _build_temporal_convolution(configuration: dict[str, Any]) -> nn.Module:
    """Build the temporal-convolution sequence network."""
    return TemporalConvolutionNetwork(
        input_size=int(configuration["input_size"]),
        output_size=int(configuration.get("output_size", 1)),
        **_convolution_kwargs(configuration),
    )


def _build_periodic_temporal_convolution(configuration: dict[str, Any]) -> nn.Module:
    """Build the periodic-feature temporal-convolution sequence network."""
    return PeriodicTemporalSequenceNetwork(
        temporal_model_type="temporal_convolution",
        input_size=int(configuration["input_size"]),
        output_size=int(configuration.get("output_size", 1)),
        **_periodic_feature_kwargs(configuration),
        **_convolution_kwargs(configuration),
    )


def _build_recurrent_sequence(recurrent_type: str, configuration: dict[str, Any]) -> nn.Module:
    """Build a plain recurrent sequence network (`recurrent_type` is `gru` or `lstm`)."""
    return RecurrentSequenceNetwork(
        recurrent_type=recurrent_type,
        input_size=int(configuration["input_size"]),
        output_size=int(configuration.get("output_size", 1)),
        **_recurrent_kwargs(configuration),
    )


def _build_periodic_recurrent_sequence(temporal_model_type: str, configuration: dict[str, Any]) -> nn.Module:
    """Build a periodic-feature recurrent network (`gru_sequence` or `lstm_sequence`)."""
    return PeriodicTemporalSequenceNetwork(
        temporal_model_type=temporal_model_type,
        input_size=int(configuration["input_size"]),
        output_size=int(configuration.get("output_size", 1)),
        **_periodic_feature_kwargs(configuration),
        **_recurrent_kwargs(configuration),
    )


def _build_residual_harmonic_recurrent_sequence(
    temporal_model_type: str,
    configuration: dict[str, Any],
) -> nn.Module:
    """Build a harmonic branch plus a recurrent residual branch (`gru_sequence`/`lstm_sequence`)."""
    return ResidualHarmonicTemporalSequenceNetwork(
        temporal_model_type=temporal_model_type,
        input_size=int(configuration["input_size"]),
        output_size=int(configuration.get("output_size", 1)),
        freeze_structured_branch=bool(configuration.get("freeze_structured_branch", False)),
        **_harmonic_kwargs(configuration),
        **_recurrent_kwargs(configuration),
    )


def _build_sequential_residual_offset_probe(configuration: dict[str, Any]) -> nn.Module:
    """Build the Wave 3.1 sequential residual-offset probe."""
    return SequentialResidualOffsetNetwork(
        input_size=int(configuration["input_size"]),
        output_size=int(configuration.get("output_size", 1)),
        base_hidden_size=list(configuration.get("base_hidden_size", [96, 64])),
        base_activation_name=str(configuration.get("base_activation_name", "GELU")),
        base_dropout_probability=float(configuration.get("base_dropout_probability", 0.05)),
        base_use_layer_norm=bool(configuration.get("base_use_layer_norm", True)),
        **_offset_branch_kwargs(configuration),
    )


def _build_harmonic_residual_offset_probe(configuration: dict[str, Any]) -> nn.Module:
    """Build the Wave 3.2 harmonic residual-offset probe (also used for the curve-aware alias)."""
    return HarmonicResidualOffsetNetwork(
        input_size=int(configuration["input_size"]),
        output_size=int(configuration.get("output_size", 1)),
        freeze_structured_branch=bool(configuration.get("freeze_structured_branch", False)),
        **_harmonic_kwargs(configuration, default_coefficient_mode="linear_conditioned"),
        **_offset_branch_kwargs(configuration),
    )


def _build_latent_state_hysteresis_probe(configuration: dict[str, Any]) -> nn.Module:
    """Build the Wave 4.4 latent-state hysteresis probe."""
    return LatentStateHysteresisNetwork(
        input_size=int(configuration["input_size"]),
        output_size=int(configuration.get("output_size", 1)),
        latent_encoder_type=str(configuration.get("latent_encoder_type", "gru")),
        latent_hidden_size=int(configuration.get("latent_hidden_size", 96)),
        latent_num_layers=int(configuration.get("latent_num_layers", 2)),
        latent_dropout_probability=float(configuration.get("latent_dropout_probability", 0.10)),
        # Passed through unconverted; may be None when the config omits it.
        latent_channel_size=configuration.get("latent_channel_size"),
        latent_kernel_size=int(configuration.get("latent_kernel_size", 5)),
        latent_activation_name=str(configuration.get("latent_activation_name", "GELU")),
        readout_position=str(configuration.get("readout_position", "last")),
        base_hidden_size=list(configuration.get("base_hidden_size", [96, 64])),
        head_hidden_size=list(configuration.get("head_hidden_size", [96, 64])),
        head_activation_name=str(configuration.get("head_activation_name", "GELU")),
        head_dropout_probability=float(configuration.get("head_dropout_probability", 0.05)),
        use_layer_norm=bool(configuration.get("use_layer_norm", True)),
        offset_scale=float(configuration.get("offset_scale", 1.0)),
        residual_scale=float(configuration.get("residual_scale", 1.0)),
    )


def _build_wave3_harmonic_prior_residual(configuration: dict[str, Any]) -> nn.Module:
    """Build the embryonic Wave 5.1 harmonic-prior residual skeleton."""
    return Wave3HarmonicPriorResidualNetwork(
        input_size=int(configuration["input_size"]),
        output_size=int(configuration.get("output_size", 1)),
        harmonic_index_list=configuration.get("harmonic_index_list"),
        freeze_structured_branch=bool(configuration.get("freeze_structured_branch", False)),
        **_wave3_shared_kwargs(configuration),
    )


def _build_wave3_grouped_harmonic_heads(configuration: dict[str, Any]) -> nn.Module:
    """Build the embryonic Wave 5.1 grouped harmonic-heads skeleton."""
    return Wave3GroupedHarmonicHeadsNetwork(
        input_size=int(configuration["input_size"]),
        output_size=int(configuration.get("output_size", 1)),
        low_order_scale=float(configuration.get("low_order_scale", 1.0)),
        stable_middle_scale=float(configuration.get("stable_middle_scale", 1.0)),
        high_order_scale=float(configuration.get("high_order_scale", 1.0)),
        freeze_harmonic_heads=bool(configuration.get("freeze_harmonic_heads", False)),
        **_wave3_shared_kwargs(configuration),
    )


# Dispatch table: normalized model-type string -> builder(configuration) -> nn.Module.
# Builders resolve the model classes only when called, so building this table
# has no side effects beyond defining it.
_MODEL_BUILDERS = {
    "feedforward": _build_feedforward,
    "harmonic_regression": _build_harmonic_regression,
    "periodic_mlp": _build_periodic_mlp,
    "residual_harmonic_mlp": _build_residual_harmonic_mlp,
    "temporal_convolution": _build_temporal_convolution,
    "periodic_temporal_convolution": _build_periodic_temporal_convolution,
    # RecurrentSequenceNetwork expects the bare cell name ("gru"/"lstm") ...
    "gru_sequence": lambda configuration: _build_recurrent_sequence("gru", configuration),
    "lstm_sequence": lambda configuration: _build_recurrent_sequence("lstm", configuration),
    # ... whereas the wrapper networks expect the full sequence model-type name.
    "periodic_gru_sequence": lambda configuration: _build_periodic_recurrent_sequence(
        "gru_sequence", configuration
    ),
    "periodic_lstm_sequence": lambda configuration: _build_periodic_recurrent_sequence(
        "lstm_sequence", configuration
    ),
    "residual_harmonic_gru_sequence": lambda configuration: _build_residual_harmonic_recurrent_sequence(
        "gru_sequence", configuration
    ),
    "residual_harmonic_lstm_sequence": lambda configuration: _build_residual_harmonic_recurrent_sequence(
        "lstm_sequence", configuration
    ),
    "sequential_residual_offset_probe": _build_sequential_residual_offset_probe,
    "harmonic_residual_offset_probe": _build_harmonic_residual_offset_probe,
    "curve_aware_harmonic_residual_offset_probe": _build_harmonic_residual_offset_probe,
    "latent_state_hysteresis_probe": _build_latent_state_hysteresis_probe,
    "wave3_harmonic_prior_residual": _build_wave3_harmonic_prior_residual,
    "wave3_grouped_harmonic_heads": _build_wave3_grouped_harmonic_heads,
}


def create_model(model_type: str, model_configuration: dict[str, Any]) -> nn.Module:
    """Instantiate one supported TE model from a configuration dictionary.

    Args:
        model_type: Model-type string. It is matched case-insensitively after
            surrounding whitespace is stripped. Supported values:
            `feedforward`, `harmonic_regression`, `periodic_mlp`,
            `residual_harmonic_mlp`, `temporal_convolution`,
            `periodic_temporal_convolution`, `gru_sequence`, `lstm_sequence`,
            `periodic_gru_sequence`, `periodic_lstm_sequence`,
            `residual_harmonic_gru_sequence`, `residual_harmonic_lstm_sequence`,
            `sequential_residual_offset_probe`, `harmonic_residual_offset_probe`,
            `curve_aware_harmonic_residual_offset_probe`,
            `latent_state_hysteresis_probe`, `wave3_harmonic_prior_residual`,
            and `wave3_grouped_harmonic_heads`.
        model_configuration: Model-specific configuration dictionary. Keys
            read by subscript in the matching builder are required; all
            others fall back to that builder's defaults.

    Returns:
        nn.Module: Instantiated PyTorch module matching the requested model
        type.

    Raises:
        ValueError: If `model_type` does not match a supported model family.
            The message lists the supported types.
        KeyError: If a key required by the selected model family is missing
            from `model_configuration`.

    Errors from the `int`/`float`/`list` conversions of configuration values,
    and from the model constructors themselves, propagate unchanged.
    """
    # Normalize Model Type
    normalized_model_type = model_type.strip().lower()
    # Resolve Builder For The Requested Family
    builder = _MODEL_BUILDERS.get(normalized_model_type)
    if builder is None:
        supported_model_types = ", ".join(sorted(_MODEL_BUILDERS))
        raise ValueError(f"Unsupported Model Type | {model_type} | Supported: {supported_model_types}")
    # Create Requested Model
    return builder(model_configuration)