# Source code for scripts.models.model_factory

"""Factory helpers that map TE model-type strings to concrete modules."""

from __future__ import annotations

# Import Typing Utilities
from typing import Any

# Import PyTorch Utilities
import torch.nn as nn

# Import Project Models
from scripts.models.feedforward_network import FeedForwardNetwork
from scripts.models.harmonic_regression import HarmonicRegression
from scripts.models.periodic_feature_network import PeriodicFeatureNetwork
from scripts.models.residual_harmonic_network import ResidualHarmonicNetwork

def _build_feedforward_network(config: dict[str, Any]) -> nn.Module:
    """Build a `FeedForwardNetwork` from its configuration entries.

    Every key read here is required; a missing key surfaces as `KeyError`.
    """
    return FeedForwardNetwork(
        input_size=int(config["input_size"]),
        hidden_size=list(config["hidden_size"]),
        output_size=int(config["output_size"]),
        activation_name=str(config["activation_name"]),
        dropout_probability=float(config["dropout_probability"]),
        use_layer_norm=bool(config["use_layer_norm"]),
    )


def _build_harmonic_regression(config: dict[str, Any]) -> nn.Module:
    """Build the `HarmonicRegression` baseline from its configuration entries.

    `input_size` and `harmonic_order` are required; `output_size` falls back
    to 1 and `coefficient_mode` falls back to "static".
    """
    return HarmonicRegression(
        input_size=int(config["input_size"]),
        output_size=int(config.get("output_size", 1)),
        harmonic_order=int(config["harmonic_order"]),
        coefficient_mode=str(config.get("coefficient_mode", "static")),
    )


def _build_periodic_feature_network(config: dict[str, Any]) -> nn.Module:
    """Build a `PeriodicFeatureNetwork` from its configuration entries.

    Same required keys as the plain feedforward model plus `harmonic_order`;
    `include_raw_angle_feature` is optional and falls back to True.
    """
    return PeriodicFeatureNetwork(
        input_size=int(config["input_size"]),
        hidden_size=list(config["hidden_size"]),
        output_size=int(config["output_size"]),
        activation_name=str(config["activation_name"]),
        dropout_probability=float(config["dropout_probability"]),
        use_layer_norm=bool(config["use_layer_norm"]),
        harmonic_order=int(config["harmonic_order"]),
        include_raw_angle_feature=bool(config.get("include_raw_angle_feature", True)),
    )


def _build_residual_harmonic_network(config: dict[str, Any]) -> nn.Module:
    """Build a `ResidualHarmonicNetwork` from its configuration entries.

    `input_size`, `harmonic_order` and `residual_hidden_size` are required;
    every other entry has a fallback value supplied below.
    """
    return ResidualHarmonicNetwork(
        input_size=int(config["input_size"]),
        output_size=int(config.get("output_size", 1)),
        harmonic_order=int(config["harmonic_order"]),
        coefficient_mode=str(config.get("coefficient_mode", "static")),
        residual_hidden_size=list(config["residual_hidden_size"]),
        residual_activation_name=str(config.get("residual_activation_name", "GELU")),
        residual_dropout_probability=float(config.get("residual_dropout_probability", 0.10)),
        residual_use_layer_norm=bool(config.get("residual_use_layer_norm", True)),
        freeze_structured_branch=bool(config.get("freeze_structured_branch", False)),
    )


# Lookup Table From Lowercase Model-Type String To Its Builder Helper
_MODEL_BUILDERS = {
    "feedforward": _build_feedforward_network,
    "harmonic_regression": _build_harmonic_regression,
    "periodic_mlp": _build_periodic_feature_network,
    "residual_harmonic_mlp": _build_residual_harmonic_network,
}


def create_model(model_type: str, model_configuration: dict[str, Any]) -> nn.Module:
    """Instantiate one supported TE model from a configuration dictionary.

    The model type is matched case-insensitively (via `str.lower`) against
    the keys of `_MODEL_BUILDERS`, and the matching builder reads its entries
    from `model_configuration`.

    Args:
        model_type: Model-type string such as `feedforward`,
            `harmonic_regression`, `periodic_mlp`, or `residual_harmonic_mlp`.
        model_configuration: Model-specific configuration dictionary.

    Returns:
        nn.Module: Instantiated PyTorch module for the requested model type.

    Raises:
        ValueError: If `model_type` does not name a supported model family.
        KeyError: If a required configuration entry is missing.
    """
    # Resolve The Builder For The Case-Normalised Model Type
    builder = _MODEL_BUILDERS.get(model_type.lower())
    if builder is None:
        # Report the caller's original spelling, not the normalised key
        raise ValueError(f"Unsupported Model Type | {model_type}")
    return builder(model_configuration)